This is an automated email from the ASF dual-hosted git repository. tallison pushed a commit to branch chardet-work in repository https://gitbox.apache.org/repos/asf/tika.git
commit dfb587c82d99d6d3de446021b1c6aa8412dd0538 Author: tballison <[email protected]> AuthorDate: Fri Feb 27 16:11:26 2026 -0500 TIKA-4662: train initial chardetect model (v1, 32 classes, 65536 buckets) 32 ML classes: excludes US-ASCII, HZ, ISO-2022-JP/KR/CN (structural gates). Training: 605K samples, 5 epochs, sparse SGD (~55s). In-sample accuracy: 84.6% strict / 92.5% soft. Sparse SGD: ByteNgramFeatureExtractor.extractSparseInto() is O(probe-length) not O(numBuckets) — ~130x speedup for 65536 buckets. STRUCTURAL_ONLY_CHARSETS in build_charset_training.py: skips train/ for zero-feature charsets, keeps devtest/test for full-pipeline eval. --- .gitignore | 2 + .../ml/chardetect/ByteNgramFeatureExtractor.java | 211 ++++++++++++ .../org/apache/tika/ml/chardetect/chardetect.bin | Bin 0 -> 2097769 bytes .../ml/chardetect/tools/TrainCharsetModel.java | 370 +++++++++++++++++++++ .../src/test/python/build_charset_training.py | 222 +++++++------ 5 files changed, 704 insertions(+), 101 deletions(-) diff --git a/.gitignore b/.gitignore index c8d1ac4bea..f74f9d7452 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,8 @@ dependency-reduced-pom.xml *.ipr *.iws *.bin +# Allow the bundled ML model resources +!**/src/main/resources/**/*.bin nbactions.xml nb-configuration.xml *.DS_Store diff --git a/tika-ml/tika-ml-chardetect/src/main/java/org/apache/tika/ml/chardetect/ByteNgramFeatureExtractor.java b/tika-ml/tika-ml-chardetect/src/main/java/org/apache/tika/ml/chardetect/ByteNgramFeatureExtractor.java new file mode 100644 index 0000000000..e72286411d --- /dev/null +++ b/tika-ml/tika-ml-chardetect/src/main/java/org/apache/tika/ml/chardetect/ByteNgramFeatureExtractor.java @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.ml.chardetect; + +import org.apache.tika.ml.FeatureExtractor; + +/** + * Feature extractor for raw bytes for charset detection, using FNV-1a hashing + * into a fixed-width bucket array. + * + * <h3>Features emitted</h3> + * <ul> + * <li><strong>Unigrams</strong>: every byte {@code b} where + * {@code (b & 0xFF) >= 0x80}. These directly encode the high-byte + * frequency distribution that distinguishes single-byte encodings + * (KOI8-R vs Windows-1251 vs ISO-8859-2, etc.).</li> + * <li><strong>Bigrams</strong>: consecutive pairs {@code (b[i], b[i+1])} + * where {@code (b[i] & 0xFF) >= 0x80}. Anchoring on a high first byte + * captures multi-byte character structure (Big5, Shift-JIS, GBK, + * EUC-* lead/trail pairs) while automatically excluding ASCII-ASCII + * pairs produced by HTML tag markup — those bytes are all below 0x80 + * and carry no charset signal.</li> + * </ul> + * + * <h3>Why the high-byte filter matters</h3> + * <p>Training data is clean text (no HTML tags). Inference data is often raw + * HTML (many ASCII tag bytes). Without the filter, the model would see a + * different byte distribution at inference time than at training time. By + * ignoring bytes below 0x80 entirely, HTML tags are invisible to both the + * training and inference feature computation — no stripping needed.</p> + * + * <h3>No salting needed</h3> + * <p>Unigrams hash values {@code 0x0080–0x00FF}; bigrams anchored on a high + * first byte produce values {@code 0x8000–0xFFFF}. These ranges do not + * overlap, so unigrams and bigrams naturally occupy different regions of the + * hash space without an explicit salt.</p> + */ +public class ByteNgramFeatureExtractor implements FeatureExtractor<byte[]> { + + private static final int FNV_PRIME = 0x01000193; + private static final int FNV_OFFSET = 0x811c9dc5; + + private final int numBuckets; + + /** + * @param numBuckets number of hash buckets (feature-vector dimension). + * 2048 is a good default: large enough to limit collisions + * across the tens of thousands of active multi-byte bigrams, + * small enough that the model stays compact. + */ + public ByteNgramFeatureExtractor(int numBuckets) { + if (numBuckets <= 0) { + throw new IllegalArgumentException("numBuckets must be positive: " + numBuckets); + } + this.numBuckets = numBuckets; + } + + @Override + public int[] extract(byte[] input) { + int[] counts = new int[numBuckets]; + if (input == null || input.length == 0) { + return counts; + } + extractInto(input, 0, input.length, counts); + return counts; + } + + /** + * Extract from a sub-range of a byte array. + */ + public int[] extract(byte[] input, int offset, int length) { + int[] counts = new int[numBuckets]; + if (input == null || length == 0) { + return counts; + } + extractInto(input, offset, offset + length, counts); + return counts; + } + + /** + * Sparse extraction into caller-owned, reusable buffers. + * + * <p>This is O(probe length), not O(numBuckets), making it safe for large + * bucket counts (e.g. 65536) in tight training loops. + * + * <p>After the call, {@code dense[touched[0..n-1]]} hold the non-zero + * counts. The caller <em>must</em> clear those entries after use: + * <pre>{@code + * for (int i = 0; i < n; i++) dense[touched[i]] = 0; + * }</pre> + * This keeps the dense buffer zeroed for the next call without a full + * {@code Arrays.fill}. + * + * @param input raw bytes to extract features from + * @param dense caller-allocated scratch buffer of length {@code numBuckets} + * (must be all-zeros on entry; caller clears it after use) + * @param touched caller-allocated buffer; receives the indices of non-zero + * buckets (length {@code numBuckets} is a safe upper bound, + * but 2 * probe.length + 2 is the true worst case) + * @return number of active entries written into {@code touched} + */ + public int extractSparseInto(byte[] input, int[] dense, int[] touched) { + if (input == null || input.length == 0) { + return 0; + } + int n = 0; + for (int i = 0; i < input.length; i++) { + int bi = input[i] & 0xFF; + if (bi < 0x80) { + continue; + } + // Unigram + int h = (FNV_OFFSET ^ bi) * FNV_PRIME; + int b = (h & 0x7fffffff) % numBuckets; + if (dense[b] == 0) { + touched[n++] = b; + } + dense[b]++; + + // Bigram + if (i + 1 < input.length) { + int bi1 = input[i + 1] & 0xFF; + h = (FNV_OFFSET ^ bi) * FNV_PRIME; + h = (h ^ bi1) * FNV_PRIME; + b = (h & 0x7fffffff) % numBuckets; + if (dense[b] == 0) { + touched[n++] = b; + } + dense[b]++; + } + } + return n; + } + + private void extractInto(byte[] b, int from, int to, int[] counts) { + for (int i = from; i < to; i++) { + int bi = b[i] & 0xFF; + if (bi < 0x80) { + continue; // ASCII — no charset signal, skip + } + + // Unigram: hash the single high byte + int h = (FNV_OFFSET ^ bi) * FNV_PRIME; + counts[bucket(h)]++; + + // Bigram: anchor on this high byte, pair with whatever follows + if (i + 1 < to) { + int bi1 = b[i + 1] & 0xFF; + h = (FNV_OFFSET ^ bi) * FNV_PRIME; + h = (h ^ bi1) * FNV_PRIME; + counts[bucket(h)]++; + } + } + } + + private int bucket(int hash) { + return (hash & 0x7fffffff) % numBuckets; + } + + @Override + public int getNumBuckets() { + return numBuckets; + } + + /** + * Returns the fraction of bytes in {@code input} that are below 0x80 and + * therefore contribute <em>no features</em> to this extractor. + * + * <p>This is the byte-level analogue of the word-level OOV ("out of + * vocabulary") rate used in language-detection evaluation: a high ratio + * means the sample is essentially pure ASCII and the model has nothing to + * distinguish it from any other encoding.</p> + * + * <p>Thresholds used by the training-data pipeline to filter out low-signal + * chunks ({@code build_charset_training.py}): + * <ul> + * <li>CJK multibyte encodings: OOV > 0.80 (i.e. high-byte ratio < 0.20)</li> + * <li>SBCS / other legacy encodings: OOV > 0.98 (high-byte ratio < 0.02)</li> + * <li>ASCII / ISO-2022 / UTF-16 / UTF-32: exempt (by design)</li> + * </ul> + * + * @return value in [0.0, 1.0]; 1.0 means all bytes are ASCII (fully OOV), + * 0.0 means all bytes are high bytes. + */ + public static double oovRate(byte[] input) { + if (input == null || input.length == 0) { + return 1.0; + } + int ascii = 0; + for (byte b : input) { + if ((b & 0xFF) < 0x80) { + ascii++; + } + } + return (double) ascii / input.length; + } +} diff --git a/tika-ml/tika-ml-chardetect/src/main/resources/org/apache/tika/ml/chardetect/chardetect.bin b/tika-ml/tika-ml-chardetect/src/main/resources/org/apache/tika/ml/chardetect/chardetect.bin new file mode 100644 index 0000000000..1f5edecfb6 Binary files /dev/null and b/tika-ml/tika-ml-chardetect/src/main/resources/org/apache/tika/ml/chardetect/chardetect.bin differ diff --git a/tika-ml/tika-ml-chardetect/src/test/java/org/apache/tika/ml/chardetect/tools/TrainCharsetModel.java b/tika-ml/tika-ml-chardetect/src/test/java/org/apache/tika/ml/chardetect/tools/TrainCharsetModel.java new file mode 100644 index 0000000000..bd9fe9546c --- /dev/null +++ b/tika-ml/tika-ml-chardetect/src/test/java/org/apache/tika/ml/chardetect/tools/TrainCharsetModel.java @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.ml.chardetect.tools; + +import java.io.DataInputStream; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.zip.GZIPInputStream; + +import org.apache.tika.ml.LinearModel; +import org.apache.tika.ml.chardetect.ByteNgramFeatureExtractor; +import org.apache.tika.ml.chardetect.CharsetConfusables; + +/** + * Trains a {@link LinearModel} for charset detection from the binary training + * data produced by {@code build_charset_training.py}. + * <p> + * <strong>Usage:</strong> + * <pre> + * java TrainCharsetModel \ + * --data /path/to/chardet-training \ + * --output chardetect.bin \ + * [--buckets 65536] \ + * [--epochs 3] \ + * [--lr 0.05] \ + * [--max-samples-per-class 500000] + * </pre> + * <p> + * Training file format (per charset, one {@code <charset>.bin.gz} file): + * <pre> + * [uint16 length][bytes of that length] + * ... repeated ... + * </pre> + * Each record is a raw byte chunk that should be classified as the charset + * named by the filename. + */ +public class TrainCharsetModel { + + private static final int DEFAULT_NUM_BUCKETS = 2048; + private static final int DEFAULT_EPOCHS = 3; + private static final float DEFAULT_LR = 0.05f; + private static final int DEFAULT_MAX_SAMPLES = 500_000; + + public static void main(String[] args) throws IOException { + Path dataDir = null; + Path outputPath = Paths.get("chardetect.bin"); + int numBuckets = DEFAULT_NUM_BUCKETS; + int epochs = DEFAULT_EPOCHS; + float lr = DEFAULT_LR; + int maxSamplesPerClass = DEFAULT_MAX_SAMPLES; + + for (int i = 0; i < args.length; i++) { + switch (args[i]) { + case "--data": + dataDir = Paths.get(args[++i]); + break; + case "--output": + outputPath = Paths.get(args[++i]); + break; + case "--buckets": + numBuckets = Integer.parseInt(args[++i]); + break; + case "--epochs": + epochs = Integer.parseInt(args[++i]); + break; + case "--lr": + lr = Float.parseFloat(args[++i]); + break; + case "--max-samples-per-class": + maxSamplesPerClass = Integer.parseInt(args[++i]); + break; + default: + System.err.println("Unknown argument: " + args[i]); + System.exit(1); + } + } + + if (dataDir == null) { + System.err.println("Usage: TrainCharsetModel --data <dir> [options]"); + System.exit(1); + } + + // Discover charset files + List<Path> charsetFiles = Files.list(dataDir) + .filter(p -> p.getFileName().toString().endsWith(".bin.gz")) + .sorted() + .collect(Collectors.toList()); + + if (charsetFiles.isEmpty()) { + System.err.println("No .bin.gz files found in: " + dataDir); + System.exit(1); + } + + // Build class label list (charset names stripped of .bin.gz) + String[] labels = charsetFiles.stream() + .map(p -> p.getFileName().toString() + .replaceAll("\\.bin\\.gz$", "")) + .toArray(String[]::new); + int numClasses = labels.length; + + System.out.println("Classes (" + numClasses + "): " + Arrays.toString(labels)); + System.out.println("Buckets: " + numBuckets + ", epochs: " + epochs + + ", lr: " + lr + ", max-samples/class: " + maxSamplesPerClass); + + ByteNgramFeatureExtractor extractor = + new ByteNgramFeatureExtractor(numBuckets); + + // Build class index map + Map<String, Integer> labelIndex = new HashMap<>(); + for (int i = 0; i < numClasses; i++) { + labelIndex.put(labels[i], i); + } + + // Load samples (up to maxSamplesPerClass per class) + System.out.println("Loading training data ..."); + List<byte[]>[] samplesPerClass = new List[numClasses]; + long totalSamples = 0; + for (int ci = 0; ci < numClasses; ci++) { + samplesPerClass[ci] = loadSamples(charsetFiles.get(ci), maxSamplesPerClass); + totalSamples += samplesPerClass[ci].size(); + System.out.printf(java.util.Locale.ROOT, " %-30s %,d samples%n", labels[ci], samplesPerClass[ci].size()); + } + System.out.printf(java.util.Locale.ROOT, "Total training samples: %,d%n", totalSamples); + + // SGD training: multinomial logistic regression with L2 regularisation + // Weight matrix: [numClasses][numBuckets] + float[][] weights = new float[numClasses][numBuckets]; + float[] biases = new float[numClasses]; + float lambda = 1e-5f; // L2 regularisation coefficient + + // Build a shuffled training index: list of (classIndex, sampleIndex) pairs + List<int[]> index = new ArrayList<>((int) Math.min(totalSamples, Integer.MAX_VALUE)); + for (int ci = 0; ci < numClasses; ci++) { + for (int si = 0; si < samplesPerClass[ci].size(); si++) { + index.add(new int[]{ci, si}); + } + } + + // Build int[][] group indices for probability collapsing and per-charset eval + int[][] groupIndices = CharsetConfusables.buildGroupIndices(labels); + + // Reusable sparse-extraction buffers (avoids per-sample allocation) + int[] denseScratch = new int[numBuckets]; + int[] touched = new int[numBuckets]; // worst-case size + + for (int epoch = 0; epoch < epochs; epoch++) { + Collections.shuffle(index); + long strictCorrect = 0; + float lossSum = 0f; + int count = 0; + + for (int[] pair : index) { + int trueClass = pair[0]; + byte[] sample = samplesPerClass[trueClass].get(pair[1]); + + // Sparse extraction: O(probeLength), not O(numBuckets) + int nActive = extractor.extractSparseInto(sample, denseScratch, touched); + + // Forward pass: only iterate active buckets + float[] logits = new float[numClasses]; + for (int c = 0; c < numClasses; c++) { + float dot = biases[c]; + for (int t = 0; t < nActive; t++) { + dot += weights[c][touched[t]] * denseScratch[touched[t]]; + } + logits[c] = dot; + } + float[] probs = LinearModel.softmax(logits.clone()); + + // Cross-entropy loss on the true class + lossSum -= (float) Math.log(Math.max(probs[trueClass], 1e-12f)); + + if (argmax(probs) == trueClass) { + strictCorrect++; + } + + // Backward pass: gradient = probs - one_hot(trueClass) + float[] grad = probs.clone(); + grad[trueClass] -= 1f; + + // Sparse SGD update with lazy L2: only active buckets are touched. + // Inactive weights start at 0 and are never pushed away without a + // gradient, so skipping their L2 decay is correct. + for (int c = 0; c < numClasses; c++) { + float g = grad[c]; + biases[c] -= lr * g; + for (int t = 0; t < nActive; t++) { + int b = touched[t]; + weights[c][b] -= lr * (g * denseScratch[b] + lambda * weights[c][b]); + } + } + count++; + + // Clear only the active entries (O(nActive), not O(numBuckets)) + for (int t = 0; t < nActive; t++) { + denseScratch[touched[t]] = 0; + } + } + + System.out.printf(java.util.Locale.ROOT, + "Epoch %d/%d loss=%.4f strict-acc=%.2f%%%n", + epoch + 1, epochs, lossSum / count, + 100.0 * strictCorrect / count); + } + + // Quantize and save + System.out.println("Quantizing ..."); + String[] qLabels = labels; + float[] qScales = new float[numClasses]; + float[] qBiases = biases; + byte[][] qWeights = new byte[numClasses][numBuckets]; + + for (int c = 0; c < numClasses; c++) { + float maxAbs = 1e-6f; + for (float w : weights[c]) { + float abs = Math.abs(w); + if (abs > maxAbs) { + maxAbs = abs; + } + } + qScales[c] = maxAbs / 127f; + for (int b = 0; b < numBuckets; b++) { + int q = Math.round(weights[c][b] / qScales[c]); + qWeights[c][b] = (byte) Math.max(-127, Math.min(127, q)); + } + } + + LinearModel model = new LinearModel(numBuckets, numClasses, + qLabels, qScales, qBiases, qWeights); + + try (OutputStream os = new FileOutputStream(outputPath.toFile())) { + model.save(os); + } + System.out.println("Model saved to: " + outputPath); + + // Per-charset evaluation on the training data (in-sample, sanity check) + System.out.println("\nPer-charset evaluation (quantized model, training data):"); + evaluatePerCharset(model, extractor, samplesPerClass, labels, groupIndices); + + System.out.println("Done."); + } + + private static List<byte[]> loadSamples(Path file, int maxSamples) throws IOException { + List<byte[]> samples = new ArrayList<>(); + try (InputStream fis = new FileInputStream(file.toFile()); + GZIPInputStream gis = new GZIPInputStream(fis); + DataInputStream dis = new DataInputStream(gis)) { + while (samples.size() < maxSamples) { + int len; + try { + len = dis.readUnsignedShort(); + } catch (java.io.EOFException e) { + break; + } + byte[] chunk = new byte[len]; + dis.readFully(chunk); + samples.add(chunk); + } + } + return samples; + } + + /** + * Evaluate the quantized model on the training samples and print a table + * showing per-charset strict accuracy and lenient accuracy (confusable group + * match counts as correct). + */ + private static void evaluatePerCharset( + LinearModel model, + ByteNgramFeatureExtractor extractor, + List<byte[]>[] samplesPerClass, + String[] labels, + int[][] groupIndices) { + + int numClasses = labels.length; + int[] strictCorrect = new int[numClasses]; + int[] lenientCorrect = new int[numClasses]; + int[] totals = new int[numClasses]; + + for (int trueClass = 0; trueClass < numClasses; trueClass++) { + for (byte[] sample : samplesPerClass[trueClass]) { + int[] features = extractor.extract(sample); + float[] probs = CharsetConfusables.collapseGroups( + model.predict(features), groupIndices); + int predicted = argmax(probs); + totals[trueClass]++; + if (predicted == trueClass) { + strictCorrect[trueClass]++; + lenientCorrect[trueClass]++; + } else if (CharsetConfusables.isLenientMatch( + labels[trueClass], labels[predicted])) { + lenientCorrect[trueClass]++; + } + } + } + + // Print table + int maxLabelLen = 0; + for (String l : labels) { + maxLabelLen = Math.max(maxLabelLen, l.length()); + } + String fmt = " %-" + maxLabelLen + "s %7d %7.2f%% %7.2f%% %s%n"; + System.out.printf(java.util.Locale.ROOT, + " %-" + maxLabelLen + "s %7s %8s %8s%n", + "Charset", "N", "Strict", "Soft"); + System.out.println(" " + "-".repeat(maxLabelLen + 32)); + + long totalStrict = 0; + long totalLenient = 0; + long totalN = 0; + for (int c = 0; c < numClasses; c++) { + int n = totals[c]; + if (n == 0) { + continue; + } + double strict = 100.0 * strictCorrect[c] / n; + double lenient = 100.0 * lenientCorrect[c] / n; + // Flag rows where lenient > strict (confusable errors are happening) + String flag = (lenientCorrect[c] > strictCorrect[c]) ? "*" : ""; + System.out.printf(java.util.Locale.ROOT, fmt, + labels[c], n, strict, lenient, flag); + totalStrict += strictCorrect[c]; + totalLenient += lenientCorrect[c]; + totalN += n; + } + System.out.println(" " + "-".repeat(maxLabelLen + 32)); + System.out.printf(java.util.Locale.ROOT, fmt, + "OVERALL", totalN, + 100.0 * totalStrict / totalN, + 100.0 * totalLenient / totalN, ""); + System.out.println(" (* = confusable-group errors present; lenient > strict)"); + } + + private static int argmax(float[] arr) { + int best = 0; + for (int i = 1; i < arr.length; i++) { + if (arr[i] > arr[best]) { + best = i; + } + } + return best; + } +} diff --git a/tika-ml/tika-ml-chardetect/src/test/python/build_charset_training.py b/tika-ml/tika-ml-chardetect/src/test/python/build_charset_training.py index a967abcf48..1435954a70 100644 --- a/tika-ml/tika-ml-chardetect/src/test/python/build_charset_training.py +++ b/tika-ml/tika-ml-chardetect/src/test/python/build_charset_training.py @@ -97,6 +97,7 @@ import os import random import struct import sys +import unicodedata from collections import defaultdict from pathlib import Path @@ -238,6 +239,20 @@ LANG_CHARSETS: dict[str, list[str]] = { # UTF-* added for all MADLAD languages below UNICODE_CHARSETS = ["UTF-8", "UTF-16-LE", "UTF-16-BE", "UTF-32-LE", "UTF-32-BE"] +# These charsets produce zero high bytes (all content < 0x80), so the ML +# feature extractor sees no signal. They are detected by structural gates +# in MlEncodingDetector before the model is ever called: +# - US-ASCII → checkAscii() → returns UTF-8 +# - HZ → checkHz() → returns HZ +# - ISO-2022* → detectIso2022() → returns ISO-2022-JP/KR/CN +# +# We still generate devtest/test files so EvalCharsetDetectors can measure +# full-pipeline accuracy (with structural gates enabled), but we skip train +# files so these classes don't pollute the ML model with zero-feature samples. +STRUCTURAL_ONLY_CHARSETS = frozenset({ + "US-ASCII", "HZ", "ISO-2022-JP", "ISO-2022-KR", "ISO-2022-CN", +}) + # Charsets sourced from Flores Traditional Chinese. # MADLAD-400 has no Cantonese/Min Nan in Han script. # *** devtest and test for these charsets are IN-SAMPLE copies of train. *** @@ -324,6 +339,11 @@ MIN_HIGH_BYTE_SBCS = 0.02 RTL_CHARSETS = {"IBM424-rtl", "IBM420-rtl"} +# CP1258 (Vietnamese) uses combining diacritical marks (NFD-style). +# MADLAD is NFC, so we must decompose before encoding and recompose before +# the drop-count comparison. +NFD_CHARSETS = {"windows-1258"} + # Maximum number of characters that may be silently dropped during encoding # (unencodable chars removed by errors='ignore') plus corrupt sequences # (U+FFFD produced on decode) before a sentence is rejected. @@ -354,18 +374,23 @@ def encode_chunk(text: str, charset: str, codec: str, if charset in RTL_CHARSETS: text = text[::-1] + # CP1258 needs NFD decomposition; compare after recomposing back to NFC + norm_text = unicodedata.normalize("NFD", text) if charset in NFD_CHARSETS else text + try: - encoded_full = text.encode(codec, errors="ignore") + encoded_full = norm_text.encode(codec, errors="ignore") except LookupError: return None if not encoded_full: return None try: - decoded = encoded_full.decode(codec, errors="replace") + decoded_raw = encoded_full.decode(codec, errors="replace") except (UnicodeDecodeError, LookupError): return None + # Recompose decoded output for fair character-count comparison + decoded = unicodedata.normalize("NFC", decoded_raw) if charset in NFD_CHARSETS else decoded_raw n_dropped = len(text) - len(decoded) n_corrupt = decoded.count("\ufffd") if n_dropped + n_corrupt > MAX_DROPPED_CHARS: @@ -548,52 +573,9 @@ def main(): print("fastText filter: disabled (model not found)") # ----------------------------------------------------------------------- - # Step 1: Build per-language sentence pools from MADLAD - # Only load languages that contribute to the requested (or all) charsets. - # ----------------------------------------------------------------------- - print("\n=== Loading and splitting MADLAD sentences ===") - lang_pools: dict[str, dict[str, list[str]]] = {} - - # Determine which charsets we're actually building - target_charsets: set[str] = set(args.charsets) if args.charsets else ( - set(CHARSET_CODEC.keys()) - ) - # Which languages contribute to those charsets? - relevant_langs: set[str] = set() - for lang, charsets in LANG_CHARSETS.items(): - lang_charsets = set(UNICODE_CHARSETS + charsets) - if lang_charsets & target_charsets: - relevant_langs.add(lang) - - for lang in sorted(relevant_langs): - lang_dir = args.madlad_dir / lang - if not lang_dir.is_dir(): - print(f" SKIP {lang}: directory not found") - continue - - raw = load_madlad_sentences(lang_dir, args.max_source_per_lang) - if not raw: - print(f" SKIP {lang}: no sentences") - continue - - # fastText contamination filter (batched) - before = len(raw) - if ft_model is not None: - raw = apply_fasttext_filter(raw, lang, ft_model) - removed = before - len(raw) - - pools = split_pool(raw, args.seed) - lang_pools[lang] = pools - print(f" {lang:6s}: {before:>7,} sentences, " - f"removed {removed:>5,} contaminated, " - f"split {len(pools['train']):>6,} / " - f"{len(pools['devtest']):>5,} / " - f"{len(pools['test']):>5,}", flush=True) - - # ----------------------------------------------------------------------- - # Step 2: Load Flores Traditional Chinese (for Big5, Big5-HKSCS, EUC-TW) + # Step 1: Load Flores Traditional Chinese (tiny — ~4K sentences — load once) # ----------------------------------------------------------------------- - print("\n=== Loading Flores Traditional Chinese ===") + print("=== Loading Flores Traditional Chinese ===") flores_sentences = load_flores_sentences(args.flores_dir) rng_flores = random.Random(args.seed) rng_flores.shuffle(flores_sentences) @@ -605,27 +587,17 @@ def main(): caps = {"train": args.train_cap, "devtest": args.devtest_cap, "test": args.test_cap} - # Determine which charsets to build and which are in-sample - all_charsets: set[str] = set() - charset_is_in_sample: dict[str, bool] = {} - for lang in lang_pools: - for cs in UNICODE_CHARSETS + LANG_CHARSETS.get(lang, []): - all_charsets.add(cs) - charset_is_in_sample[cs] = False - for cs in FLORES_CHARSETS: - if flores_sentences: - all_charsets.add(cs) - charset_is_in_sample[cs] = True - - if args.charsets: - all_charsets &= set(args.charsets) - all_charsets = {cs for cs in all_charsets + # Determine which charsets to build + target_charsets: set[str] = set(args.charsets) if args.charsets else ( + set(CHARSET_CODEC.keys()) + ) + all_charsets = {cs for cs in target_charsets if cs in CHARSET_CODEC and probe_codec(cs, CHARSET_CODEC[cs])} + flores_charsets = {cs for cs in FLORES_CHARSETS if flores_sentences} & all_charsets # ----------------------------------------------------------------------- - # Step 3 + 4: For each charset, build sentence pool on-the-fly and encode. - # We process one charset at a time to avoid holding millions of sentences - # in memory simultaneously. + # Step 2 + 3: For each charset, load only the languages it needs, encode, + # then free memory. Never hold more than one charset's sentence pool in RAM. # ----------------------------------------------------------------------- print(f"\n=== Encoding {len(all_charsets)} charsets ===") print(f" Caps: train={args.train_cap:,} " @@ -638,50 +610,98 @@ def main(): for charset in sorted(all_charsets): codec = CHARSET_CODEC[charset] - in_sample = charset_is_in_sample.get(charset, False) + in_sample = charset in flores_charsets flag = " *** IN-SAMPLE ***" if in_sample else "" print(f" {charset}{flag}", flush=True) split_counts: dict[str, int] = {} - for split in split_names: - cap = caps[split] - - # Collect sentences from contributing languages, capped per lang - # to avoid memory blowup and give each language a fair share. - if in_sample: - # Flores: all sentences go to every split (in-sample) - sents = list(flores_sentences) - else: - contributing = [ - lang for lang in lang_pools - if charset in (UNICODE_CHARSETS + LANG_CHARSETS.get(lang, [])) - ] - # Per-language cap: take at most 3× cap / num_langs from each, - # with a floor so small-language charsets still get enough data. - per_lang = max(200, (cap * 3) // max(1, len(contributing))) - sents = [] - for lang in contributing: - pool = lang_pools[lang][split] - sents.extend(pool[:per_lang]) - - # Shuffle to interleave languages - rng_split = random.Random(args.seed + hash(charset + split)) - rng_split.shuffle(sents) - - out_file = args.output_dir / split / f"{charset}.bin.gz" - n = build_split_file( - sents, charset, codec, out_file, - cap, args.min_chunk, args.max_chunk, - random.Random(args.seed + hash(charset + split + "enc")), + if in_sample: + # Flores charsets: same sentences go to all splits (in-sample eval) + for split in split_names: + out_file = args.output_dir / split / f"{charset}.bin.gz" + n = build_split_file( + list(flores_sentences), charset, codec, out_file, + caps[split], args.min_chunk, args.max_chunk, + random.Random(args.seed + hash(charset + split + "enc")), + ) + split_counts[split] = n + print(f" {split:7s}: {n:>6,} samples") + else: + # MADLAD charsets: determine contributing languages + is_unicode = charset in UNICODE_CHARSETS + contributing = sorted( + lang for lang, csets in LANG_CHARSETS.items() + if is_unicode or charset in csets + ) + + # Load sentences for each contributing language. + # Per-language load cap: we need roughly train_cap*6/n_langs sentences + # total to survive encoding rejections; the 80/10/10 split means + # train gets 80% of loaded data. Use max_source_per_lang as ceiling. + n_langs = max(1, len(contributing)) + load_cap = min( + args.max_source_per_lang, + max(5_000, (args.train_cap * 10) // n_langs), ) - split_counts[split] = n - print(f" {split:7s}: {n:>6,} samples") + + lang_pools: dict[str, dict[str, list[str]]] = {} + for lang in contributing: + lang_dir = args.madlad_dir / lang + if not lang_dir.is_dir(): + continue + raw = load_madlad_sentences(lang_dir, load_cap) + if not raw: + continue + before = len(raw) + if ft_model is not None: + raw = apply_fasttext_filter(raw, lang, ft_model) + lang_pools[lang] = split_pool(raw, args.seed) + removed = before - len(raw) + pools = lang_pools[lang] + print(f" load {lang}: {before:>6,} → split " + f"{len(pools['train']):>5,}/{len(pools['devtest']):>4,}" + f"/{len(pools['test']):>4,}" + + (f" (-{removed} filtered)" if removed else ""), + flush=True) + + structural_only = charset in STRUCTURAL_ONLY_CHARSETS + if structural_only: + print(f" (structural-only: skipping train, generating devtest/test only)") + + for split in split_names: + # Skip train split for structural-only charsets: the ML model + # sees zero features for these encodings; structural gates in + # MlEncodingDetector handle them before the model is called. + if structural_only and split == "train": + split_counts[split] = 0 + continue + + cap = caps[split] + per_lang = max(200, (cap * 4) // max(1, len(lang_pools))) + sents: list[str] = [] + for lang, pools in lang_pools.items(): + sents.extend(pools[split][:per_lang]) + rng_split = random.Random(args.seed + hash(charset + split)) + rng_split.shuffle(sents) + + out_file = args.output_dir / split / f"{charset}.bin.gz" + n = build_split_file( + sents, charset, codec, out_file, + cap, args.min_chunk, args.max_chunk, + random.Random(args.seed + hash(charset + split + "enc")), + ) + split_counts[split] = n + print(f" {split:7s}: {n:>6,} samples") + + # Release the sentence pools for this charset before moving on + del lang_pools manifest[charset] = { - "codec": codec, - "eval_is_in_sample": in_sample, - "samples": split_counts, + "codec": codec, + "eval_is_in_sample": in_sample, + "structural_only": charset in STRUCTURAL_ONLY_CHARSETS, + "samples": split_counts, } # -----------------------------------------------------------------------
