This is an automated email from the ASF dual-hosted git repository.

huaxingao pushed a commit to branch index_poc
in repository https://gitbox.apache.org/repos/asf/iceberg.git
commit de683c927d2a40c61a44cdc65e428d3c1f361fa1 Author: Huaxin Gao <[email protected]> AuthorDate: Thu Feb 12 22:35:45 2026 -0800 Bloom filter index POC --- .../spark/actions/BloomFilterIndexUtil.java | 695 ++++++++++++++++ .../actions/BuildBloomFilterIndexSparkAction.java | 205 +++++ .../apache/iceberg/spark/actions/SparkActions.java | 10 + .../iceberg/spark/source/SparkBatchQueryScan.java | 97 +++ .../actions/TestBuildBloomFilterIndexAction.java | 116 +++ .../benchmark/TestBloomFilterIndexBenchmark.java | 894 +++++++++++++++++++++ 6 files changed, 2017 insertions(+) diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/actions/BloomFilterIndexUtil.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/actions/BloomFilterIndexUtil.java new file mode 100644 index 0000000000..cf67d87d7e --- /dev/null +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/actions/BloomFilterIndexUtil.java @@ -0,0 +1,695 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.BitSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import org.apache.iceberg.ContentScanTask; +import org.apache.iceberg.PartitionScanTask; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.StatisticsFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.puffin.Blob; +import org.apache.iceberg.puffin.BlobMetadata; +import org.apache.iceberg.puffin.FileMetadata; +import org.apache.iceberg.puffin.Puffin; +import org.apache.iceberg.puffin.PuffinCompressionCodec; +import org.apache.iceberg.puffin.PuffinReader; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.SparkTableUtil; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.functions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Minimal utilities for building and consulting a Bloom-filter-based file-skipping index 
stored in + * Puffin statistics files. + * + * <p>This is a proof-of-concept implementation intended to demonstrate that Bloom filters can + * improve query performance by pruning data files. It is intentionally limited in scope: + * + * <ul> + * <li>Only equality predicates on a single column are supported. + * <li>Only un-nested, primitive columns are supported. + * <li>Bloom filters are built per data file using Spark and stored in a single statistics file + * per snapshot. + * <li>The index is best-effort: if anything looks inconsistent or unsupported, callers should + * ignore it and fall back to normal planning. + * </ul> + */ +public class BloomFilterIndexUtil { + + private static final Logger LOG = LoggerFactory.getLogger(BloomFilterIndexUtil.class); + + // Blob type used for Bloom filters inside Puffin statistics files. + // Kept package-private so both actions and scan-side code can share it. + static final String BLOOM_FILTER_BLOB_TYPE = "bloom-filter-v1"; + + // Property keys on each Bloom blob + static final String PROP_DATA_FILE = "data-file"; + static final String PROP_COLUMN_NAME = "column-name"; + static final String PROP_FPP = "fpp"; + static final String PROP_NUM_VALUES = "num-values"; + static final String PROP_NUM_BITS = "num-bits"; + static final String PROP_NUM_HASHES = "num-hashes"; + + // Heuristic Bloom filter sizing for the POC + private static final double DEFAULT_FPP = 0.01; + private static final long DEFAULT_EXPECTED_VALUES_PER_FILE = 100_000L; + + private BloomFilterIndexUtil() {} + + /** + * Build Bloom-filter blobs for a single column of a snapshot and return them in memory. + * + * <p>This uses Spark to read the table at a given snapshot, groups by input file, and collects + * values for the target column on the driver to build per-file Bloom filters. It is suitable for + * small/medium demo tables, not for production-scale index building. + */ + static List<Blob> buildBloomBlobsForColumn( + SparkSession spark, Table table, Snapshot snapshot, String columnName) { + Preconditions.checkNotNull(snapshot, "snapshot must not be null"); + Preconditions.checkArgument( + columnName != null && !columnName.isEmpty(), "columnName must not be null/empty"); + + Dataset<Row> df = SparkTableUtil.loadTable(spark, table, snapshot.snapshotId()); + + // Attach per-row file path using Spark's built-in function. This relies on the underlying + // reader exposing file information, which is already true for Iceberg's Spark integration. 
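+ // Note: input_file_name() returns an empty string when the source does not expose a file path;
+ // this POC assumes file-backed Iceberg data files, so every row is expected to carry a concrete path.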
+ Column fileCol = functions.input_file_name().alias("_file"); + Dataset<Row> fileAndValue = + df.select(functions.col(columnName), fileCol).na().drop(); // drop nulls for simplicity + + Dataset<Row> perFileValues = + fileAndValue.groupBy("_file").agg(functions.collect_list(columnName).alias("values")); + + List<Row> rows = perFileValues.collectAsList(); + LOG.info( + "Building Bloom-filter blobs for column {} in table {} (snapshot {}, {} file group(s))", + columnName, + table.name(), + snapshot.snapshotId(), + rows.size()); + + Schema schema = table.schemas().get(snapshot.schemaId()); + Types.NestedField field = schema.findField(columnName); + Preconditions.checkArgument( + field != null, "Cannot find column %s in schema %s", columnName, schema); + Preconditions.checkArgument( + supportedFieldType(field.type()), + "Unsupported Bloom index column type %s for column %s (supported: string, int, long, uuid)", + field.type(), + columnName); + + List<Blob> blobs = Lists.newArrayListWithExpectedSize(rows.size()); + + for (Row row : rows) { + String filePath = row.getString(0); + @SuppressWarnings("unchecked") + List<Object> values = row.getList(1); + + if (values == null || values.isEmpty()) { + continue; + } + + SimpleBloomFilter bloom = + SimpleBloomFilter.create( + (int) Math.max(DEFAULT_EXPECTED_VALUES_PER_FILE, values.size()), DEFAULT_FPP); + + long nonNullCount = 0L; + for (Object value : values) { + if (value != null) { + byte[] canonicalBytes = canonicalBytes(value); + Preconditions.checkArgument( + canonicalBytes != null, + "Unsupported Bloom index value type %s for column %s", + value.getClass().getName(), + columnName); + bloom.put(canonicalBytes); + nonNullCount++; + } + } + + if (nonNullCount == 0L) { + continue; + } + + ByteBuffer serialized = serializeBloomBits(bloom); + + Map<String, String> properties = + ImmutableMap.of( + PROP_DATA_FILE, + filePath, + PROP_COLUMN_NAME, + columnName, + PROP_FPP, + String.valueOf(DEFAULT_FPP), + PROP_NUM_VALUES, + String.valueOf(nonNullCount), + PROP_NUM_BITS, + String.valueOf(bloom.numBits()), + PROP_NUM_HASHES, + String.valueOf(bloom.numHashFunctions())); + + Blob blob = + new Blob( + BLOOM_FILTER_BLOB_TYPE, + ImmutableList.of(field.fieldId()), + snapshot.snapshotId(), + snapshot.sequenceNumber(), + serialized, + PuffinCompressionCodec.ZSTD, + properties); + + blobs.add(blob); + } + + return blobs; + } + + private static ByteBuffer serializeBloomBits(SimpleBloomFilter bloom) { + return ByteBuffer.wrap(bloom.toBitsetBytes()); + } + + private static byte[] toByteArray(ByteBuffer buffer) { + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + return bytes; + } + + /** + * Use Bloom-filter blobs stored in statistics files to prune data files for an equality predicate + * on a single column. + * + * <p>The caller is responsible for: + * + * <ul> + * <li>Ensuring the predicate is an equality of a supported literal type. + * <li>Passing the snapshot used for planning (so we can ignore stale stats). + * <li>Falling back cleanly if this method returns the original tasks. 
+ * </ul> + * + * @param table the Iceberg table + * @param snapshot the snapshot being scanned (may be null) + * @param tasksSupplier supplier that returns the already planned tasks (typically calls + * super.tasks()) + * @param columnName the column name used in the equality predicate + * @param literalValue the literal value used in the equality predicate + * @return either the original tasks or a filtered subset if Bloom pruning was applied + */ + public static <T extends PartitionScanTask> List<T> pruneTasksWithBloomIndex( + Table table, + Snapshot snapshot, + Supplier<List<T>> tasksSupplier, + String columnName, + Object literalValue) { + + if (snapshot == null) { + return tasksSupplier.get(); + } + + List<StatisticsFile> statsFilesForSnapshot = + table.statisticsFiles().stream() + .filter(sf -> sf.snapshotId() == snapshot.snapshotId()) + .collect(Collectors.toList()); + + if (statsFilesForSnapshot.isEmpty()) { + return tasksSupplier.get(); + } + + byte[] literalBytes = canonicalBytes(literalValue); + if (literalBytes == null) { + // Unsupported literal type for this portable encoding; do not prune. + return tasksSupplier.get(); + } + String columnNameLower = columnName.toLowerCase(Locale.ROOT); + + Set<String> candidateFiles = + loadCandidateFilesFromBloom(table, statsFilesForSnapshot, columnNameLower, literalBytes); + + if (candidateFiles == null) { + // Index missing/unusable; do not change planning. + return tasksSupplier.get(); + } + + if (candidateFiles.isEmpty()) { + // Bloom filters have no false negatives. If the index is usable but no files matched, we can + // safely prune to an empty scan without planning tasks. + return ImmutableList.of(); + } + + List<T> tasks = tasksSupplier.get(); + List<T> filtered = + tasks.stream() + .filter( + task -> { + if (task instanceof ContentScanTask) { + ContentScanTask<?> contentTask = (ContentScanTask<?>) task; + String path = contentTask.file().path().toString(); + return candidateFiles.contains(path); + } + // If we don't know how to interpret the task, keep it for safety. + return true; + }) + .collect(Collectors.toList()); + + if (filtered.size() == tasks.size()) { + // No pruning happened; return the original list to avoid surprising equals/hashCode behavior. 
+ return tasks; + } + + LOG.info( + "Bloom index pruned {} of {} task(s) for table {} on column {} = {}", + tasks.size() - filtered.size(), + tasks.size(), + table.name(), + columnName, + literalValue); + + return filtered; + } + + private static Set<String> loadCandidateFilesFromBloom( + Table table, List<StatisticsFile> statsFiles, String columnNameLower, byte[] literalBytes) { + + Set<String> candidateFiles = Sets.newHashSet(); + boolean indexFound = false; + boolean indexUsable = false; + + for (StatisticsFile stats : statsFiles) { + InputFile inputFile = table.io().newInputFile(stats.path()); + + try (PuffinReader reader = + Puffin.read(inputFile).withFileSize(stats.fileSizeInBytes()).build()) { + FileMetadata fileMetadata = reader.fileMetadata(); + + List<BlobMetadata> bloomBlobs = + fileMetadata.blobs().stream() + .filter( + bm -> + BLOOM_FILTER_BLOB_TYPE.equals(bm.type()) + && columnMatches(bm, table, columnNameLower)) + .collect(Collectors.toList()); + + if (bloomBlobs.isEmpty()) { + continue; + } + + indexFound = true; + Iterable<Pair<BlobMetadata, ByteBuffer>> blobData = reader.readAll(bloomBlobs); + + for (Pair<BlobMetadata, ByteBuffer> pair : blobData) { + BlobMetadata bm = pair.first(); + ByteBuffer data = pair.second(); + + String dataFile = bm.properties().get(PROP_DATA_FILE); + if (dataFile == null) { + continue; + } + + Integer numBits = parsePositiveInt(bm.properties().get(PROP_NUM_BITS)); + Integer numHashes = parsePositiveInt(bm.properties().get(PROP_NUM_HASHES)); + if (numBits == null || numHashes == null) { + continue; + } + + indexUsable = true; + SimpleBloomFilter bloom = + SimpleBloomFilter.fromBitsetBytes(numBits, numHashes, toByteArray(data.duplicate())); + if (bloom.mightContain(literalBytes)) { + candidateFiles.add(dataFile); + } + } + + } catch (Exception e) { + LOG.warn( + "Failed to read Bloom index from statistics file {} for table {}, skipping it", + stats.path(), + table.name(), + e); + } + } + + if (!indexFound || !indexUsable) { + return null; + } + + return candidateFiles; + } + + private static Integer parsePositiveInt(String value) { + if (value == null) { + return null; + } + + try { + int parsed = Integer.parseInt(value); + return parsed > 0 ? parsed : null; + } catch (NumberFormatException e) { + return null; + } + } + + private static boolean columnMatches(BlobMetadata bm, Table table, String columnNameLower) { + List<Integer> fields = bm.inputFields(); + if (fields == null || fields.isEmpty()) { + return false; + } + + int fieldId = fields.get(0); + String colName = table.schema().findColumnName(fieldId); + return colName != null && colName.toLowerCase(Locale.ROOT).equals(columnNameLower); + } + + private static boolean supportedFieldType(Type type) { + return type.typeId() == Type.TypeID.STRING + || type.typeId() == Type.TypeID.INTEGER + || type.typeId() == Type.TypeID.LONG + || type.typeId() == Type.TypeID.UUID; + } + + /** + * Canonical bytes for Phase 1 portable encoding: + * + * <ul> + * <li>string: UTF-8 bytes + * <li>int: 4 bytes two's complement big-endian + * <li>long: 8 bytes two's complement big-endian + * <li>uuid: 16 bytes (MSB 8 bytes big-endian + LSB 8 bytes big-endian) + * </ul> + * + * <p>Returns null if the type is unsupported. + */ + private static byte[] canonicalBytes(Object value) { + if (value instanceof String) { + return ((String) value).getBytes(StandardCharsets.UTF_8); + } + + // Spark may use UTF8String in some paths; treat it as a string value for this encoding. 
+ if (value instanceof org.apache.spark.unsafe.types.UTF8String) { + return value.toString().getBytes(StandardCharsets.UTF_8); + } + + if (value instanceof Integer) { + int intValue = (Integer) value; + return new byte[] { + (byte) (intValue >>> 24), (byte) (intValue >>> 16), (byte) (intValue >>> 8), (byte) intValue + }; + } + + if (value instanceof Long) { + long longValue = (Long) value; + return new byte[] { + (byte) (longValue >>> 56), + (byte) (longValue >>> 48), + (byte) (longValue >>> 40), + (byte) (longValue >>> 32), + (byte) (longValue >>> 24), + (byte) (longValue >>> 16), + (byte) (longValue >>> 8), + (byte) longValue + }; + } + + if (value instanceof UUID) { + UUID uuid = (UUID) value; + long msb = uuid.getMostSignificantBits(); + long lsb = uuid.getLeastSignificantBits(); + return new byte[] { + (byte) (msb >>> 56), + (byte) (msb >>> 48), + (byte) (msb >>> 40), + (byte) (msb >>> 32), + (byte) (msb >>> 24), + (byte) (msb >>> 16), + (byte) (msb >>> 8), + (byte) msb, + (byte) (lsb >>> 56), + (byte) (lsb >>> 48), + (byte) (lsb >>> 40), + (byte) (lsb >>> 32), + (byte) (lsb >>> 24), + (byte) (lsb >>> 16), + (byte) (lsb >>> 8), + (byte) lsb + }; + } + + return null; + } + + /** + * Minimal Bloom filter implementation for the POC using Murmur3 x64 128-bit and standard + * double-hashing to derive multiple hash functions. + */ + private static final class SimpleBloomFilter { + private final int numBits; + private final int numHashFunctions; + private final BitSet bits; + + private SimpleBloomFilter(int numBits, int numHashFunctions) { + this.numBits = numBits; + this.numHashFunctions = numHashFunctions; + this.bits = new BitSet(numBits); + } + + private SimpleBloomFilter(int numBits, int numHashFunctions, BitSet bits) { + this.numBits = numBits; + this.numHashFunctions = numHashFunctions; + this.bits = bits; + } + + static SimpleBloomFilter create(int expectedInsertions, double fpp) { + int numBits = Math.max(8 * expectedInsertions, 1); // very rough heuristic, 8 bits/value min + int numHashFunctions = + Math.max(2, (int) Math.round(-Math.log(fpp) / Math.log(2))); // ~ ln(1/fpp)/ln(2) + return new SimpleBloomFilter(numBits, numHashFunctions); + } + + static SimpleBloomFilter fromBitsetBytes( + int numBits, int numHashFunctions, byte[] bitsetBytes) { + int requiredBytes = (numBits + 7) / 8; + Preconditions.checkArgument( + bitsetBytes.length == requiredBytes, + "Invalid Bloom bitset length: expected %s bytes, got %s bytes", + requiredBytes, + bitsetBytes.length); + BitSet bits = BitSet.valueOf(bitsetBytes); + return new SimpleBloomFilter(numBits, numHashFunctions, bits); + } + + int numBits() { + return numBits; + } + + int numHashFunctions() { + return numHashFunctions; + } + + byte[] toBitsetBytes() { + int requiredBytes = (numBits + 7) / 8; + byte[] bytes = new byte[requiredBytes]; + byte[] encoded = bits.toByteArray(); // bit 0 is LSB of byte 0 + System.arraycopy(encoded, 0, bytes, 0, Math.min(encoded.length, bytes.length)); + return bytes; + } + + void put(byte[] valueBytes) { + long[] hashes = murmur3x64_128(valueBytes); + long hash1 = hashes[0]; + long hash2 = hashes[1]; + for (int i = 0; i < numHashFunctions; i++) { + long combined = hash1 + (long) i * hash2; + int index = (int) Long.remainderUnsigned(combined, (long) numBits); + bits.set(index); + } + } + + boolean mightContain(byte[] valueBytes) { + long[] hashes = murmur3x64_128(valueBytes); + long hash1 = hashes[0]; + long hash2 = hashes[1]; + for (int i = 0; i < numHashFunctions; i++) { + long combined = hash1 + (long) i * 
hash2; + int index = (int) Long.remainderUnsigned(combined, (long) numBits); + if (!bits.get(index)) { + return false; + } + } + return true; + } + + // MurmurHash3 x64 128-bit, seed=0. + // See https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp + private static long[] murmur3x64_128(byte[] data) { + final int length = data.length; + final int nblocks = length / 16; + + long h1 = 0L; + long h2 = 0L; + + final long c1 = 0x87c37b91114253d5L; + final long c2 = 0x4cf5ad432745937fL; + + // body + for (int i = 0; i < nblocks; i++) { + int offset = i * 16; + long k1 = getLittleEndianLong(data, offset); + long k2 = getLittleEndianLong(data, offset + 8); + + k1 *= c1; + k1 = Long.rotateLeft(k1, 31); + k1 *= c2; + h1 ^= k1; + + h1 = Long.rotateLeft(h1, 27); + h1 += h2; + h1 = h1 * 5 + 0x52dce729; + + k2 *= c2; + k2 = Long.rotateLeft(k2, 33); + k2 *= c1; + h2 ^= k2; + + h2 = Long.rotateLeft(h2, 31); + h2 += h1; + h2 = h2 * 5 + 0x38495ab5; + } + + // tail + long k1 = 0L; + long k2 = 0L; + int tailStart = nblocks * 16; + switch (length & 15) { + case 15: + k2 ^= ((long) data[tailStart + 14] & 0xff) << 48; + // fall through + case 14: + k2 ^= ((long) data[tailStart + 13] & 0xff) << 40; + // fall through + case 13: + k2 ^= ((long) data[tailStart + 12] & 0xff) << 32; + // fall through + case 12: + k2 ^= ((long) data[tailStart + 11] & 0xff) << 24; + // fall through + case 11: + k2 ^= ((long) data[tailStart + 10] & 0xff) << 16; + // fall through + case 10: + k2 ^= ((long) data[tailStart + 9] & 0xff) << 8; + // fall through + case 9: + k2 ^= ((long) data[tailStart + 8] & 0xff); + k2 *= c2; + k2 = Long.rotateLeft(k2, 33); + k2 *= c1; + h2 ^= k2; + // fall through + case 8: + k1 ^= ((long) data[tailStart + 7] & 0xff) << 56; + // fall through + case 7: + k1 ^= ((long) data[tailStart + 6] & 0xff) << 48; + // fall through + case 6: + k1 ^= ((long) data[tailStart + 5] & 0xff) << 40; + // fall through + case 5: + k1 ^= ((long) data[tailStart + 4] & 0xff) << 32; + // fall through + case 4: + k1 ^= ((long) data[tailStart + 3] & 0xff) << 24; + // fall through + case 3: + k1 ^= ((long) data[tailStart + 2] & 0xff) << 16; + // fall through + case 2: + k1 ^= ((long) data[tailStart + 1] & 0xff) << 8; + // fall through + case 1: + k1 ^= ((long) data[tailStart] & 0xff); + k1 *= c1; + k1 = Long.rotateLeft(k1, 31); + k1 *= c2; + h1 ^= k1; + // fall through + default: + // no tail + } + + // finalization + h1 ^= length; + h2 ^= length; + + h1 += h2; + h2 += h1; + + h1 = fmix64(h1); + h2 = fmix64(h2); + + h1 += h2; + h2 += h1; + + return new long[] {h1, h2}; + } + + private static long getLittleEndianLong(byte[] data, int offset) { + return ((long) data[offset] & 0xff) + | (((long) data[offset + 1] & 0xff) << 8) + | (((long) data[offset + 2] & 0xff) << 16) + | (((long) data[offset + 3] & 0xff) << 24) + | (((long) data[offset + 4] & 0xff) << 32) + | (((long) data[offset + 5] & 0xff) << 40) + | (((long) data[offset + 6] & 0xff) << 48) + | (((long) data[offset + 7] & 0xff) << 56); + } + + private static long fmix64(long value) { + long result = value; + result ^= result >>> 33; + result *= 0xff51afd7ed558ccdL; + result ^= result >>> 33; + result *= 0xc4ceb9fe1a85ec53L; + result ^= result >>> 33; + return result; + } + } +} diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/actions/BuildBloomFilterIndexSparkAction.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/actions/BuildBloomFilterIndexSparkAction.java new file mode 100644 index 0000000000..571f808633 --- /dev/null +++ 
b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/actions/BuildBloomFilterIndexSparkAction.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.actions; + +import java.io.IOException; +import java.util.List; +import java.util.UUID; +import org.apache.iceberg.GenericBlobMetadata; +import org.apache.iceberg.GenericStatisticsFile; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.IcebergBuild; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.StatisticsFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.exceptions.RuntimeIOException; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.puffin.Blob; +import org.apache.iceberg.puffin.Puffin; +import org.apache.iceberg.puffin.PuffinWriter; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.JobGroupInfo; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A minimal Spark action that builds a Bloom-filter-based file index for a single column and stores + * it as a Puffin statistics file. + * + * <p>This is intentionally narrow in scope and intended only as a proof of concept. It computes a + * Bloom filter per data file for a given column and writes all Bloom blobs into a single statistics + * file that is attached to the table metadata for the chosen snapshot. 
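+ *
+ * <p>Illustrative usage sketch (mirrors the POC tests; {@code table} and the column name are placeholders):
+ *
+ * <pre>
+ *   BuildBloomFilterIndexSparkAction.Result result =
+ *       SparkActions.get(spark).buildBloomFilterIndex(table).column("id").execute();
+ *   List&lt;StatisticsFile&gt; statsFiles = result.statisticsFiles();
+ * </pre>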
+ */ +public class BuildBloomFilterIndexSparkAction + extends BaseSparkAction<BuildBloomFilterIndexSparkAction> { + + private static final Logger LOG = LoggerFactory.getLogger(BuildBloomFilterIndexSparkAction.class); + + public static class Result { + private final List<StatisticsFile> statisticsFiles; + + Result(List<StatisticsFile> statisticsFiles) { + this.statisticsFiles = statisticsFiles; + } + + public List<StatisticsFile> statisticsFiles() { + return statisticsFiles; + } + + public List<org.apache.iceberg.DataFile> rewrittenDataFiles() { + return ImmutableList.of(); + } + + public List<org.apache.iceberg.DataFile> addedDataFiles() { + return ImmutableList.of(); + } + + public List<org.apache.iceberg.DeleteFile> rewrittenDeleteFiles() { + return ImmutableList.of(); + } + + public List<org.apache.iceberg.DeleteFile> addedDeleteFiles() { + return ImmutableList.of(); + } + + @Override + public String toString() { + return String.format("BuildBloomFilterIndexResult(statisticsFiles=%s)", statisticsFiles); + } + } + + private final Table table; + private Snapshot snapshot; + private String column; + + BuildBloomFilterIndexSparkAction(SparkSession spark, Table table) { + super(spark); + this.table = table; + this.snapshot = table.currentSnapshot(); + } + + @Override + protected BuildBloomFilterIndexSparkAction self() { + return this; + } + + public BuildBloomFilterIndexSparkAction column(String columnName) { + Preconditions.checkArgument( + columnName != null && !columnName.isEmpty(), "Column name must not be null/empty"); + this.column = columnName; + return this; + } + + public BuildBloomFilterIndexSparkAction snapshot(long snapshotId) { + Snapshot newSnapshot = table.snapshot(snapshotId); + Preconditions.checkArgument(newSnapshot != null, "Snapshot not found: %s", snapshotId); + this.snapshot = newSnapshot; + return this; + } + + public Result execute() { + Preconditions.checkNotNull(column, "Column must be set before executing Bloom index build"); + if (snapshot == null) { + LOG.info("No snapshot to index for table {}", table.name()); + return new Result(ImmutableList.of()); + } + + JobGroupInfo info = newJobGroupInfo("BUILD-BLOOM-INDEX", jobDesc()); + return withJobGroupInfo(info, this::doExecute); + } + + private Result doExecute() { + LOG.info( + "Building Bloom index for column {} in {} (snapshot {})", + column, + table.name(), + snapshot.snapshotId()); + + List<Blob> blobs = + BloomFilterIndexUtil.buildBloomBlobsForColumn(spark(), table, snapshot, column); + + if (blobs.isEmpty()) { + LOG.info( + "No Bloom blobs generated for column {} in table {} (snapshot {}), skipping write", + column, + table.name(), + snapshot.snapshotId()); + return new Result(ImmutableList.of()); + } + + StatisticsFile statsFile = writeStatsFile(blobs); + table.updateStatistics().setStatistics(statsFile).commit(); + + return new Result(ImmutableList.of(statsFile)); + } + + private StatisticsFile writeStatsFile(List<Blob> blobs) { + LOG.info( + "Writing Bloom index stats for table {} for snapshot {} ({} blob(s))", + table.name(), + snapshot.snapshotId(), + blobs.size()); + OutputFile outputFile = table.io().newOutputFile(outputPath()); + try (PuffinWriter writer = + Puffin.write(outputFile) + .createdBy(appIdentifier()) + .compressBlobs(org.apache.iceberg.puffin.PuffinCompressionCodec.ZSTD) + .build()) { + blobs.forEach(writer::add); + writer.finish(); + return new GenericStatisticsFile( + snapshot.snapshotId(), + outputFile.location(), + writer.fileSize(), + writer.footerSize(), + 
GenericBlobMetadata.from(writer.writtenBlobsMetadata())); + } catch (IOException e) { + throw new RuntimeIOException(e); + } + } + + private String appIdentifier() { + String icebergVersion = IcebergBuild.fullVersion(); + String sparkVersion = spark().version(); + return String.format("Iceberg %s Spark %s (BloomIndexPOC)", icebergVersion, sparkVersion); + } + + private String jobDesc() { + return String.format( + "Building Bloom index for %s (snapshot_id=%s, column=%s)", + table.name(), snapshot.snapshotId(), column); + } + + private String outputPath() { + TableOperations operations = ((HasTableOperations) table).operations(); + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < column.length(); i++) { + char ch = column.charAt(i); + sb.append(Character.isLetterOrDigit(ch) ? ch : '_'); + } + String sanitizedCol = sb.toString(); + String fileName = + String.format( + "%s-%s-bloom-%s.stats", snapshot.snapshotId(), UUID.randomUUID(), sanitizedCol); + return operations.metadataFileLocation(fileName); + } +} diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java index b7361c336a..e248d2bc73 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java @@ -119,4 +119,14 @@ public class SparkActions implements ActionsProvider { public RewriteTablePathSparkAction rewriteTablePath(Table table) { return new RewriteTablePathSparkAction(spark, table); } + + /** + * Build a minimal Bloom-filter-based file index for a single column and store it as a Puffin + * statistics file. + * + * <p>This is a proof-of-concept helper intended for experimentation and benchmarking. 
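+ *
+ * <p>For example (the column name is a placeholder):
+ *
+ * <pre>
+ *   SparkActions.get(spark).buildBloomFilterIndex(table).column("id").execute();
+ * </pre>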
+ */ + public BuildBloomFilterIndexSparkAction buildBloomFilterIndex(Table table) { + return new BuildBloomFilterIndexSparkAction(spark, table); + } } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java index a361a7f1ba..f29b70c44c 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java @@ -51,6 +51,7 @@ import org.apache.iceberg.spark.Spark3Util; import org.apache.iceberg.spark.SparkReadConf; import org.apache.iceberg.spark.SparkSchemaUtil; import org.apache.iceberg.spark.SparkV2Filters; +import org.apache.iceberg.spark.actions.BloomFilterIndexUtil; import org.apache.iceberg.util.ContentFileUtil; import org.apache.iceberg.util.DeleteFileSet; import org.apache.iceberg.util.SnapshotUtil; @@ -73,6 +74,10 @@ class SparkBatchQueryScan extends SparkPartitioningAwareScan<PartitionScanTask> private final Long asOfTimestamp; private final String tag; private final List<Expression> runtimeFilterExpressions; + // lazily initialized; used to avoid repeatedly walking filter expressions + private transient String bloomFilterColumn; + private transient Object bloomFilterLiteral; + private transient boolean bloomFilterDetectionAttempted; SparkBatchQueryScan( SparkSession spark, @@ -101,6 +106,32 @@ class SparkBatchQueryScan extends SparkPartitioningAwareScan<PartitionScanTask> return PartitionScanTask.class; } + /** + * Override task planning to optionally apply Bloom-filter-based file pruning based on a simple + * equality predicate. + * + * <p>This uses Bloom-filter blobs stored in Puffin statistics files (see {@link + * org.apache.iceberg.spark.actions.BuildBloomFilterIndexSparkAction}) to prune data files before + * split planning. If the index is missing, stale, or the filters are unsupported, this method + * falls back to the default behavior in {@link SparkPartitioningAwareScan}. + */ + @Override + protected synchronized List<PartitionScanTask> tasks() { + detectBloomFilterPredicateIfNeeded(); + + if (bloomFilterColumn == null || bloomFilterLiteral == null) { + return super.tasks(); + } + + Snapshot snapshot = null; + if (scan() != null) { + snapshot = SnapshotUtil.latestSnapshot(table(), branch()); + } + + return BloomFilterIndexUtil.pruneTasksWithBloomIndex( + table(), snapshot, super::tasks, bloomFilterColumn, bloomFilterLiteral); + } + @Override public NamedReference[] filterAttributes() { Set<Integer> partitionFieldSourceIds = Sets.newHashSet(); @@ -212,6 +243,72 @@ class SparkBatchQueryScan extends SparkPartitioningAwareScan<PartitionScanTask> return runtimeFilterExpr; } + /** + * Detect a simple equality predicate that can take advantage of a Bloom index. + * + * <p>For this proof-of-concept we support only predicates of the form {@code col = literal} on a + * top-level column. Nested fields, non-equality predicates, and complex expressions are ignored. 
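+ *
+ * <p>For example, {@code id = 42} or {@code data = 'x'} would qualify, while {@code id > 42},
+ * {@code id IN (1, 2)}, and {@code address.city = 'x'} would not (column names here are purely
+ * illustrative).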
+ */ + private void detectBloomFilterPredicateIfNeeded() { + if (bloomFilterDetectionAttempted) { + return; + } + + this.bloomFilterDetectionAttempted = true; + + Schema snapshotSchema = SnapshotUtil.schemaFor(table(), branch()); + + for (Expression expr : filterExpressions()) { + if (expr.op() != Expression.Operation.EQ) { + continue; + } + + try { + Expression bound = Binder.bind(snapshotSchema.asStruct(), expr, caseSensitive()); + if (!(bound instanceof org.apache.iceberg.expressions.BoundPredicate)) { + continue; + } + + org.apache.iceberg.expressions.BoundPredicate<?> predicate = + (org.apache.iceberg.expressions.BoundPredicate<?>) bound; + + if (!(predicate.term() instanceof org.apache.iceberg.expressions.BoundReference)) { + continue; + } + + org.apache.iceberg.expressions.BoundReference<?> ref = + (org.apache.iceberg.expressions.BoundReference<?>) predicate.term(); + int fieldId = ref.fieldId(); + String colName = snapshotSchema.findColumnName(fieldId); + if (colName == null) { + continue; + } + + if (!predicate.isLiteralPredicate()) { + continue; + } + + Object literal = predicate.asLiteralPredicate().literal().value(); + if (literal == null) { + continue; + } + + this.bloomFilterColumn = colName; + this.bloomFilterLiteral = literal; + + LOG.info( + "Detected Bloom-index-eligible predicate {} = {} for table {}", + colName, + literal, + table().name()); + return; + + } catch (Exception e) { + LOG.debug("Failed to bind expression {} for Bloom index detection", expr, e); + } + } + } + @Override public Statistics estimateStatistics() { if (scan() == null) { diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/actions/TestBuildBloomFilterIndexAction.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/actions/TestBuildBloomFilterIndexAction.java new file mode 100644 index 0000000000..58c729b2e6 --- /dev/null +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/actions/TestBuildBloomFilterIndexAction.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.StatisticsFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(ParameterizedTestExtension.class) +public class TestBuildBloomFilterIndexAction extends CatalogTestBase { + + @AfterEach + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testBloomIndexActionWritesStatisticsFile() + throws NoSuchTableException, ParseException { + sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName); + + spark + .createDataset( + ImmutableList.of( + new org.apache.iceberg.spark.source.SimpleRecord(1, "a"), + new org.apache.iceberg.spark.source.SimpleRecord(2, "b"), + new org.apache.iceberg.spark.source.SimpleRecord(3, "c")), + Encoders.bean(org.apache.iceberg.spark.source.SimpleRecord.class)) + .writeTo(tableName) + .append(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + Snapshot snapshot = table.currentSnapshot(); + assertThat(snapshot).isNotNull(); + + SparkActions actions = SparkActions.get(spark); + BuildBloomFilterIndexSparkAction.Result result = + actions.buildBloomFilterIndex(table).column("id").execute(); + + assertThat(result).isNotNull(); + List<StatisticsFile> statisticsFiles = table.statisticsFiles(); + assertThat(statisticsFiles).isNotEmpty(); + assertThat(statisticsFiles.get(0).fileSizeInBytes()).isGreaterThan(0L); + } + + @TestTemplate + public void testBloomIndexPrunesTasksForEqualityPredicate() + throws NoSuchTableException, ParseException { + sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName); + + // Two groups of values so that only one file should match id = 1 + spark + .createDataset( + ImmutableList.of( + new org.apache.iceberg.spark.source.SimpleRecord(1, "a"), + new org.apache.iceberg.spark.source.SimpleRecord(1, "b"), + new org.apache.iceberg.spark.source.SimpleRecord(2, "c")), + Encoders.bean(org.apache.iceberg.spark.source.SimpleRecord.class)) + .repartition(2) + .writeTo(tableName) + .append(); + + Table table = Spark3Util.loadIcebergTable(spark, tableName); + SparkActions actions = SparkActions.get(spark); + + // Build Bloom index on id + actions.buildBloomFilterIndex(table).column("id").execute(); + table.refresh(); + + // Plan tasks and validate Bloom pruning via the utility directly (unit-level sanity check) + List<FileScanTask> allTasks; + try (CloseableIterable<FileScanTask> planned = table.newScan().planFiles()) { + allTasks = Lists.newArrayList(planned); + } catch (Exception e) { + throw new RuntimeException(e); + } + + List<FileScanTask> prunedTasks = + BloomFilterIndexUtil.pruneTasksWithBloomIndex( + table, table.currentSnapshot(), () -> allTasks, "id", 1); + + 
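+ // Bloom filters have no false negatives, so pruning can only drop files that definitely do not
+ // contain id = 1; asserting <= rather than an exact count keeps the check robust to how Spark
+ // splits the two partitions into data files.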
assertThat(prunedTasks.size()).isLessThanOrEqualTo(allTasks.size()); + } +} diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/benchmark/TestBloomFilterIndexBenchmark.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/benchmark/TestBloomFilterIndexBenchmark.java new file mode 100644 index 0000000000..4510fe5b93 --- /dev/null +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/benchmark/TestBloomFilterIndexBenchmark.java @@ -0,0 +1,894 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.benchmark; + +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.SecureRandom; +import java.util.Comparator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.StatisticsFile; +import org.apache.iceberg.Table; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.puffin.BlobMetadata; +import org.apache.iceberg.puffin.FileMetadata; +import org.apache.iceberg.puffin.Puffin; +import org.apache.iceberg.puffin.PuffinReader; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.spark.Spark3Util; +import org.apache.iceberg.spark.actions.BloomFilterIndexUtil; +import org.apache.iceberg.spark.actions.BuildBloomFilterIndexSparkAction; +import org.apache.iceberg.spark.actions.SparkActions; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.execution.SparkPlan; +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec; +import org.apache.spark.sql.execution.metric.SQLMetric; +import org.apache.spark.sql.functions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import scala.collection.JavaConverters; + +/** + * A local-only benchmark for the Puffin "bloom-filter-v1" file-skipping index. + * + * <p>Disabled by default. 
Run explicitly with: + * + * <pre> + * ICEBERG_BLOOM_INDEX_BENCHMARK=true ./gradlew -DsparkVersions=4.1 \ + * :iceberg-spark:iceberg-spark-4.1_2.13:test \ + * --tests org.apache.iceberg.spark.benchmark.TestBloomFilterIndexBenchmark + * </pre> + * + * <p>Optional tuning via system properties: + * + * <ul> + * <li>iceberg.benchmark.numFiles (default: 1000) + * <li>iceberg.benchmark.rowsPerFile (default: 2000) + * <li>iceberg.benchmark.warmupRuns (default: 3) + * <li>iceberg.benchmark.measuredRuns (default: 10) + * <li>iceberg.benchmark.numDays (default: 10) for the partitioned table + * </ul> + */ +public class TestBloomFilterIndexBenchmark { + + @TempDir private Path tempDir; + + private static final String BLOOM_BLOB_TYPE = "bloom-filter-v1"; + private static final String PROP_DATA_FILE = "data-file"; + private static final String PROP_NUM_BITS = "num-bits"; + private static final String PROP_NUM_HASHES = "num-hashes"; + + @Test + public void runBloomIndexBenchmark() { + boolean enabled = + Boolean.parseBoolean(System.getenv().getOrDefault("ICEBERG_BLOOM_INDEX_BENCHMARK", "false")) + || Boolean.getBoolean("iceberg.bloomIndexBenchmark"); + assumeTrue( + enabled, + "Benchmark disabled. Re-run with ICEBERG_BLOOM_INDEX_BENCHMARK=true (or -Diceberg.bloomIndexBenchmark=true)"); + + BenchmarkConfig config = BenchmarkConfig.fromSystemProperties(); + + // Keep needles inside the same lexical domain as the table values (hex strings), + // otherwise min/max pruning can hide the bloom index benefit. + String needle = randomHex64(); + String miss = randomHex64(); + while (miss.equals(needle)) { + miss = randomHex64(); + } + + StringBuilder report = new StringBuilder(); + + Path warehouse = tempDir.resolve("warehouse"); + SparkSession spark = newSpark(warehouse, config.shufflePartitions()); + + try { + spark.sql("CREATE NAMESPACE IF NOT EXISTS bench.default"); + + log(report, ""); + log(report, "==== Iceberg Bloom Index Benchmark (local) ===="); + log(report, "numFiles=" + config.numFiles() + ", rowsPerFile=" + config.rowsPerFile()); + log(report, "warmupRuns=" + config.warmupRuns() + ", measuredRuns=" + config.measuredRuns()); + log(report, "needle=" + needle); + log(report, ""); + + for (DatasetProfile profile : DatasetProfile.profilesToRun()) { + runDatasetProfile(spark, config, needle, miss, profile, report); + } + + } finally { + spark.stop(); + } + + writeReport(report); + } + + private static String randomHex64() { + byte[] bytes = new byte[32]; + new SecureRandom().nextBytes(bytes); + StringBuilder sb = new StringBuilder(64); + for (byte b : bytes) { + sb.append(String.format(Locale.ROOT, "%02x", b)); + } + return sb.toString(); + } + + private enum DatasetProfile { + STRESS("stress", "Stress (min/max defeated)", true), + REALISTIC("real", "Realistic (random high-cardinality)", false); + + private final String suffix; + private final String displayName; + private final boolean defeatMinMax; + + DatasetProfile(String suffix, String displayName, boolean defeatMinMax) { + this.suffix = suffix; + this.displayName = displayName; + this.defeatMinMax = defeatMinMax; + } + + String suffix() { + return suffix; + } + + String displayName() { + return displayName; + } + + boolean defeatMinMax() { + return defeatMinMax; + } + + static List<DatasetProfile> profilesToRun() { + // Optional filter: ICEBERG_BENCHMARK_DATASETS=stress,real (default: both). 
+ String raw = System.getenv().getOrDefault("ICEBERG_BENCHMARK_DATASETS", ""); + if (raw == null || raw.trim().isEmpty()) { + return List.of(STRESS, REALISTIC); + } + + String lowered = raw.toLowerCase(Locale.ROOT); + boolean wantStress = lowered.contains("stress"); + boolean wantReal = lowered.contains("real"); + + if (wantStress && wantReal) { + return List.of(STRESS, REALISTIC); + } else if (wantStress) { + return List.of(STRESS); + } else if (wantReal) { + return List.of(REALISTIC); + } + + // Unknown value; fall back to both. + return List.of(STRESS, REALISTIC); + } + } + + private static void log(StringBuilder report, String line) { + System.out.println(line); + report.append(line).append(System.lineSeparator()); + } + + private static void writeReport(StringBuilder report) { + // Write under dev/ so it's easy to find and not ignored by tooling. + Path reportPath = Paths.get("dev/bloom-index-benchmark.txt"); + try { + Files.createDirectories(reportPath.getParent()); + Files.writeString(reportPath, report.toString()); + System.out.println(); + System.out.println("Benchmark report written to: " + reportPath.toAbsolutePath()); + } catch (IOException e) { + throw new RuntimeException("Failed to write benchmark report", e); + } + } + + private static void logTableFileCount( + StringBuilder report, SparkSession spark, String tableName, String label) { + try { + Table table = Spark3Util.loadIcebergTable(spark, tableName); + int totalFiles = 0; + try (CloseableIterable<FileScanTask> tasks = table.newScan().planFiles()) { + for (FileScanTask ignored : tasks) { + totalFiles++; + } + } + log(report, "Table files (" + label + "): " + tableName + " -> " + totalFiles); + } catch (Exception e) { + log(report, "Table files (" + label + "): " + tableName + " -> n/a (" + e + ")"); + } + } + + private static void logBloomIndexArtifactStats( + StringBuilder report, SparkSession spark, String tableName, String columnName, String label) { + try { + Table table = Spark3Util.loadIcebergTable(spark, tableName); + Snapshot snapshot = table.currentSnapshot(); + if (snapshot == null) { + log(report, "Index artifacts (" + label + "): no current snapshot"); + return; + } + + org.apache.iceberg.types.Types.NestedField field = table.schema().findField(columnName); + if (field == null) { + log(report, "Index artifacts (" + label + "): column not found: " + columnName); + return; + } + + int fieldId = field.fieldId(); + List<StatisticsFile> statsFilesForSnapshot = + statsFilesForSnapshot(table, snapshot.snapshotId()); + BloomIndexArtifactStats stats = scanBloomIndexBlobs(table, statsFilesForSnapshot, fieldId); + + log( + report, + String.format( + Locale.ROOT, + "Index artifacts (%s): statsFiles=%d statsBytes=%d bloomBlobs=%d bloomPayloadBytes=%d coveredDataFiles=%d blobsMissingRequiredProps=%d", + label, + stats.statsFileCount(), + stats.statsFilesBytes(), + stats.bloomBlobCount(), + stats.bloomPayloadBytes(), + stats.coveredDataFiles(), + stats.bloomBlobsMissingRequiredProps())); + + } catch (Exception e) { + log(report, "Index artifacts (" + label + "): n/a (" + e + ")"); + } + } + + private static List<StatisticsFile> statsFilesForSnapshot(Table table, long snapshotId) { + return table.statisticsFiles().stream() + .filter(sf -> sf.snapshotId() == snapshotId) + .collect(Collectors.toList()); + } + + private static BloomIndexArtifactStats scanBloomIndexBlobs( + Table table, List<StatisticsFile> statsFilesForSnapshot, int fieldId) throws IOException { + int statsFileCount = statsFilesForSnapshot.size(); + long 
statsFilesBytes = + statsFilesForSnapshot.stream().mapToLong(StatisticsFile::fileSizeInBytes).sum(); + + int bloomBlobCount = 0; + long bloomPayloadBytes = 0L; + int bloomBlobsMissingRequiredProps = 0; + Set<String> dataFilesCovered = Sets.newHashSet(); + + for (StatisticsFile stats : statsFilesForSnapshot) { + InputFile inputFile = table.io().newInputFile(stats.path()); + try (PuffinReader reader = + Puffin.read(inputFile).withFileSize(stats.fileSizeInBytes()).build()) { + FileMetadata fileMetadata = reader.fileMetadata(); + for (BlobMetadata bm : fileMetadata.blobs()) { + if (!isBloomBlobForField(bm, fieldId)) { + continue; + } + + bloomBlobCount++; + bloomPayloadBytes += bm.length(); + + Map<String, String> props = bm.properties(); + if (props == null) { + bloomBlobsMissingRequiredProps++; + continue; + } + + String dataFile = props.get(PROP_DATA_FILE); + if (dataFile != null) { + dataFilesCovered.add(dataFile); + } + + if (!hasRequiredBloomProps(props)) { + bloomBlobsMissingRequiredProps++; + } + } + } + } + + return new BloomIndexArtifactStats( + statsFileCount, + statsFilesBytes, + bloomBlobCount, + bloomPayloadBytes, + dataFilesCovered.size(), + bloomBlobsMissingRequiredProps); + } + + private static boolean isBloomBlobForField(BlobMetadata bm, int fieldId) { + if (!BLOOM_BLOB_TYPE.equals(bm.type())) { + return false; + } + + List<Integer> fields = bm.inputFields(); + if (fields == null || fields.isEmpty()) { + return false; + } + + return fields.get(0) == fieldId; + } + + private static boolean hasRequiredBloomProps(Map<String, String> props) { + return props.get(PROP_NUM_BITS) != null && props.get(PROP_NUM_HASHES) != null; + } + + private static class BloomIndexArtifactStats { + private final int statsFileCount; + private final long statsFilesBytes; + private final int bloomBlobCount; + private final long bloomPayloadBytes; + private final int coveredDataFiles; + private final int bloomBlobsMissingRequiredProps; + + private BloomIndexArtifactStats( + int statsFileCount, + long statsFilesBytes, + int bloomBlobCount, + long bloomPayloadBytes, + int coveredDataFiles, + int bloomBlobsMissingRequiredProps) { + this.statsFileCount = statsFileCount; + this.statsFilesBytes = statsFilesBytes; + this.bloomBlobCount = bloomBlobCount; + this.bloomPayloadBytes = bloomPayloadBytes; + this.coveredDataFiles = coveredDataFiles; + this.bloomBlobsMissingRequiredProps = bloomBlobsMissingRequiredProps; + } + + private int statsFileCount() { + return statsFileCount; + } + + private long statsFilesBytes() { + return statsFilesBytes; + } + + private int bloomBlobCount() { + return bloomBlobCount; + } + + private long bloomPayloadBytes() { + return bloomPayloadBytes; + } + + private int coveredDataFiles() { + return coveredDataFiles; + } + + private int bloomBlobsMissingRequiredProps() { + return bloomBlobsMissingRequiredProps; + } + } + + private static SparkSession newSpark(Path warehouse, int shufflePartitions) { + return SparkSession.builder() + .appName("IcebergBloomIndexBenchmark") + .master("local[*]") + .config("spark.ui.enabled", "false") + // Use loopback to avoid artifact/classloader RPC issues in local tests. 
+ .config("spark.driver.host", "127.0.0.1") + .config("spark.driver.bindAddress", "127.0.0.1") + .config("spark.sql.shuffle.partitions", String.valueOf(shufflePartitions)) + .config("spark.sql.adaptive.enabled", "false") + .config("spark.sql.catalog.bench", "org.apache.iceberg.spark.SparkCatalog") + .config("spark.sql.catalog.bench.type", "hadoop") + .config("spark.sql.catalog.bench.warehouse", warehouse.toAbsolutePath().toString()) + .getOrCreate(); + } + + private static void createTable( + SparkSession spark, String tableName, boolean partitioned, boolean parquetBloom) { + spark.sql("DROP TABLE IF EXISTS " + tableName); + + String partitionClause = partitioned ? "PARTITIONED BY (day)" : ""; + + StringBuilder props = new StringBuilder(); + props.append("'write.format.default'='parquet'"); + props.append(", 'write.distribution-mode'='none'"); + props.append(", 'write.spark.fanout.enabled'='true'"); + // Keep files small so we get many files even on local. + props.append(", 'write.target-file-size-bytes'='1048576'"); + + if (parquetBloom) { + props.append(", 'write.parquet.bloom-filter-enabled.column.id'='true'"); + props.append(", 'write.parquet.bloom-filter-fpp.column.id'='0.01'"); + } + + spark.sql( + String.format( + Locale.ROOT, + "CREATE TABLE %s (day string, id string, file_id int, payload string) USING iceberg %s " + + "TBLPROPERTIES (%s)", + tableName, + partitionClause, + props)); + } + + private static void runDatasetProfile( + SparkSession spark, + BenchmarkConfig config, + String needle, + String miss, + DatasetProfile profile, + StringBuilder report) { + + String suffix = profile.suffix(); + boolean defeatMinMax = profile.defeatMinMax(); + + String baseTable = "bench.default.bloom_bench_base_" + suffix; + String rowGroupTable = "bench.default.bloom_bench_rowgroup_" + suffix; + String puffinTable = "bench.default.bloom_bench_puffin_" + suffix; + String partitionedPuffinTable = "bench.default.bloom_bench_part_puffin_" + suffix; + + log(report, "==== Dataset: " + profile.displayName() + " ===="); + log(report, "defeatMinMax=" + defeatMinMax); + if (defeatMinMax) { + log( + report, + "Notes: Each data file contains id=min(0x00..00) and id=max(0xff..ff) rows so min/max stats cannot prune for id = <randomHex>."); + log( + report, + " The needle value is injected into exactly one file; the miss value is not present in any file."); + log( + report, + " This isolates the incremental benefit of Puffin bloom-filter-v1 file-level pruning vs baseline/row-group bloom."); + } else { + log( + report, + "Notes: High-cardinality random IDs (UUID/hash-like). 
For most rows: id = sha2(\"salt:file_id:row_id\", 256)."); + log( + report, + " The needle value is injected into exactly one file; the miss value is not present in any file."); + } + + // A) baseline: no parquet row-group bloom, no puffin bloom index + createTable(spark, baseTable, false /* partitioned */, false /* parquetBloom */); + writeData(spark, baseTable, config, false /* partitioned */, needle, defeatMinMax); + logTableFileCount(report, spark, baseTable, "base-" + suffix); + + // B) row-group bloom only: enable Parquet bloom filters for 'id' + createTable(spark, rowGroupTable, false /* partitioned */, true /* parquetBloom */); + writeData(spark, rowGroupTable, config, false /* partitioned */, needle, defeatMinMax); + logTableFileCount(report, spark, rowGroupTable, "rowgroup-" + suffix); + + // C) puffin bloom index only: build puffin bloom index on id + createTable(spark, puffinTable, false /* partitioned */, false /* parquetBloom */); + writeData(spark, puffinTable, config, false /* partitioned */, needle, defeatMinMax); + logTableFileCount(report, spark, puffinTable, "puffin-before-index-" + suffix); + buildPuffinBloomIndex(spark, puffinTable, "id"); + logBloomIndexArtifactStats(report, spark, puffinTable, "id", "puffin-index-" + suffix); + + // D) partitioned + puffin bloom index: demonstrate benefit within a pruned partition + createTable(spark, partitionedPuffinTable, true /* partitioned */, false /* parquetBloom */); + writeData(spark, partitionedPuffinTable, config, true /* partitioned */, needle, defeatMinMax); + logTableFileCount( + report, spark, partitionedPuffinTable, "partitioned-puffin-before-index-" + suffix); + buildPuffinBloomIndex(spark, partitionedPuffinTable, "id"); + logBloomIndexArtifactStats( + report, spark, partitionedPuffinTable, "id", "partitioned-puffin-index-" + suffix); + + log(report, ""); + + // Unpartitioned: point predicate + benchmarkScenario( + spark, + "A) baseline (no row-group bloom, no puffin index)", + baseTable, + "SELECT count(*) AS c FROM %s WHERE id = '%s'", + needle, + miss, + config, + report); + + benchmarkScenario( + spark, + "B) row-group bloom only (Parquet row-group bloom on id)", + rowGroupTable, + "SELECT count(*) AS c FROM %s WHERE id = '%s'", + needle, + miss, + config, + report); + + benchmarkScenario( + spark, + "C) puffin bloom index only (file-skipping bloom-filter-v1 on id)", + puffinTable, + "SELECT count(*) AS c FROM %s WHERE id = '%s'", + needle, + miss, + config, + report); + + // Partitioned: prune by day then apply puffin bloom inside the remaining partition + String needleDay = config.needleDay(); + String missDay = needleDay; + benchmarkScenario( + spark, + "D) partitioned + puffin bloom index (day + id predicates)", + partitionedPuffinTable, + "SELECT count(*) AS c FROM %s WHERE day = '%s' AND id = '%s'", + needleDay + "|" + needle, + missDay + "|" + miss, + config, + report); + } + + private static void writeData( + SparkSession spark, + String tableName, + BenchmarkConfig config, + boolean partitioned, + String needle, + boolean defeatMinMax) { + int numFiles = config.numFiles(); + int rowsPerFile = config.rowsPerFile(); + int numDays = Math.max(1, config.numDays()); + + // Choose a single file to contain the needle value. 
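+ // A mid-range file is chosen for the needle, and the needle is repeated a bounded number of
+ // times so the hit query returns a small but non-trivial count.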
+ int needleFileId = Math.max(0, numFiles / 2); + int needleRepeats = Math.min(50, Math.max(1, rowsPerFile / 20)); + + long totalRows = (long) numFiles * (long) rowsPerFile; + Dataset<Row> df = spark.range(totalRows).toDF(); + + Column fileId = df.col("id").divide(rowsPerFile).cast("int").alias("file_id"); + Column posInFile = + functions.pmod(df.col("id"), functions.lit(rowsPerFile)).cast("int").alias("pos_in_file"); + df = df.select(df.col("id").alias("row_id"), fileId, posInFile); + + Column dayCol; + if (partitioned) { + // Stable day string: 2026-01-01 + (file_id % numDays). + dayCol = + functions + .date_format( + functions.date_add( + functions.lit("2026-01-01").cast("date"), + functions.pmod(df.col("file_id"), functions.lit(numDays))), + "yyyy-MM-dd") + .alias("day"); + } else { + dayCol = functions.lit("1970-01-01").alias("day"); + } + + // Deterministic high-cardinality IDs so per-file min/max is generally not selective. + Column randomId = + functions + .sha2( + functions.concat_ws( + ":", functions.lit("salt"), df.col("file_id"), df.col("row_id")), + 256) + .alias("random_id"); + + Column idCol; + if (defeatMinMax) { + String minId = "0000000000000000000000000000000000000000000000000000000000000000"; + String maxId = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"; + idCol = + functions + // Force identical min/max across files to defeat min/max pruning. + .when(df.col("pos_in_file").equalTo(0), functions.lit(minId)) + .when(df.col("pos_in_file").equalTo(1), functions.lit(maxId)) + // Inject the needle into exactly one file. + .when( + df.col("file_id") + .equalTo(needleFileId) + .and(df.col("pos_in_file").geq(2)) + .and(df.col("pos_in_file").lt(needleRepeats + 2)), + functions.lit(needle)) + .otherwise(randomId) + .alias("id"); + } else { + idCol = + functions + // Inject the needle into exactly one file. + .when( + df.col("file_id") + .equalTo(needleFileId) + .and(df.col("pos_in_file").lt(needleRepeats)), + functions.lit(needle)) + .otherwise(randomId) + .alias("id"); + } + + Column payload = + functions + .sha2(functions.concat_ws("-", functions.lit("p"), df.col("row_id")), 256) + .alias("payload"); + + Dataset<Row> out = df.select(dayCol, idCol, df.col("file_id"), payload); + + // Encourage one output file per file_id (and per day when partitioned). + Dataset<Row> repartitioned = + partitioned + ? 
out.repartition(numFiles, out.col("day"), out.col("file_id"))
+            : out.repartition(numFiles, out.col("file_id"));
+
+    try {
+      repartitioned.writeTo(tableName).append();
+    } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException e) {
+      throw new RuntimeException("Table not found: " + tableName, e);
+    }
+  }
+
+  private static void buildPuffinBloomIndex(
+      SparkSession spark, String tableName, String columnName) {
+    Table table;
+    try {
+      table = Spark3Util.loadIcebergTable(spark, tableName);
+    } catch (org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+        | org.apache.spark.sql.catalyst.parser.ParseException e) {
+      throw new RuntimeException("Failed to load Iceberg table: " + tableName, e);
+    }
+    SparkActions.get(spark).buildBloomFilterIndex(table).column(columnName).execute();
+    table.refresh();
+  }
+
+  private static void benchmarkScenario(
+      SparkSession spark,
+      String label,
+      String tableName,
+      String queryTemplate,
+      String needleValue,
+      String missValue,
+      BenchmarkConfig config,
+      StringBuilder report) {
+
+    log(report, "---- " + label + " ----");
+
+    ScenarioResult needle =
+        runQueryWorkload(spark, tableName, queryTemplate, needleValue, config, true /* allowCompoundValue */);
+    ScenarioResult miss =
+        runQueryWorkload(spark, tableName, queryTemplate, missValue, config, true /* allowCompoundValue */);
+
+    log(report, "Needle: " + needle.summary());
+    log(report, "Miss : " + miss.summary());
+    log(report, "");
+  }
+
+  private static ScenarioResult runQueryWorkload(
+      SparkSession spark,
+      String tableName,
+      String queryTemplate,
+      String value,
+      BenchmarkConfig config,
+      boolean allowCompoundValue) {
+
+    String query;
+    String dayFilter = null;
+    String idFilter;
+    if (allowCompoundValue && value.contains("|") && queryTemplate.contains("day =")) {
+      String[] parts = value.split("\\|", 2);
+      dayFilter = parts[0];
+      idFilter = parts[1];
+      query = String.format(Locale.ROOT, queryTemplate, tableName, dayFilter, idFilter);
+    } else {
+      idFilter = value;
+      query = String.format(Locale.ROOT, queryTemplate, tableName, idFilter);
+    }
+
+    FileCounts fileCounts = estimateFileCounts(spark, tableName, dayFilter, idFilter);
+
+    // Warmup
+    for (int i = 0; i < config.warmupRuns(); i++) {
+      spark.sql(query).collectAsList();
+    }
+
+    List<Long> durationsMs = Lists.newArrayList();
+    List<Long> numFiles = Lists.newArrayList();
+    List<Long> bytes = Lists.newArrayList();
+
+    for (int i = 0; i < config.measuredRuns(); i++) {
+      QueryMetrics queryMetrics = timedCollect(spark, query);
+      durationsMs.add(queryMetrics.durationMs());
+      numFiles.add(queryMetrics.numFiles());
+      bytes.add(queryMetrics.bytesRead());
+    }
+
+    return new ScenarioResult(
+        durationsMs, numFiles, bytes, fileCounts.plannedFiles(), fileCounts.afterBloomFiles());
+  }
+
+  private static FileCounts estimateFileCounts(
+      SparkSession spark, String tableName, String dayFilter, String idFilter) {
+    try {
+      Table table = Spark3Util.loadIcebergTable(spark, tableName);
+      org.apache.iceberg.TableScan scan = table.newScan().filter(Expressions.equal("id", idFilter));
+      if (dayFilter != null) {
+        scan = scan.filter(Expressions.equal("day", dayFilter));
+      }
+
+      List<FileScanTask> plannedTasks = Lists.newArrayList();
+      try (CloseableIterable<FileScanTask> tasks = scan.planFiles()) {
+        for (FileScanTask t : tasks) {
+          plannedTasks.add(t);
+        }
+      }
+
+      List<FileScanTask> afterBloom =
+          BloomFilterIndexUtil.pruneTasksWithBloomIndex(
+              table, table.currentSnapshot(), () -> plannedTasks, "id", idFilter);
+
+
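+      // plannedFiles: tasks left after Iceberg's regular metadata pruning; afterBloomFiles: tasks that also pass the Puffin bloom-filter index check.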
return new FileCounts(plannedTasks.size(), afterBloom.size()); + + } catch (Exception e) { + // If anything goes wrong, don't fail the benchmark; just omit file counts. + return new FileCounts(-1, -1); + } + } + + private static QueryMetrics timedCollect(SparkSession spark, String query) { + Dataset<Row> ds = spark.sql(query); + long startNs = System.nanoTime(); + ds.collectAsList(); + long durationMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs); + + SparkPlan plan = ds.queryExecution().executedPlan(); + if (plan instanceof AdaptiveSparkPlanExec) { + plan = ((AdaptiveSparkPlanExec) plan).executedPlan(); + } + + long numFiles = sumMetricByName(plan, "numFiles"); + long bytesRead = sumMetricsMatching(plan, "bytes"); + + return new QueryMetrics(durationMs, numFiles, bytesRead); + } + + private static long sumMetricByName(SparkPlan plan, String metricName) { + long sum = 0L; + Map<String, SQLMetric> metrics = JavaConverters.mapAsJavaMap(plan.metrics()); + SQLMetric metric = metrics.get(metricName); + if (metric != null) { + sum += metric.value(); + } + + scala.collection.Iterator<SparkPlan> it = plan.children().iterator(); + while (it.hasNext()) { + sum += sumMetricByName(it.next(), metricName); + } + + return sum; + } + + private static long sumMetricsMatching(SparkPlan plan, String tokenLower) { + long sum = 0L; + Map<String, SQLMetric> metrics = JavaConverters.mapAsJavaMap(plan.metrics()); + for (Map.Entry<String, SQLMetric> entry : metrics.entrySet()) { + String key = entry.getKey(); + if (key != null && key.toLowerCase(Locale.ROOT).contains(tokenLower)) { + sum += entry.getValue().value(); + } + } + + scala.collection.Iterator<SparkPlan> it = plan.children().iterator(); + while (it.hasNext()) { + sum += sumMetricsMatching(it.next(), tokenLower); + } + + return sum; + } + + private record QueryMetrics(long durationMs, long numFiles, long bytesRead) {} + + private record ScenarioResult( + List<Long> durationsMs, + List<Long> numFiles, + List<Long> bytesRead, + int plannedFiles, + int afterBloomFiles) { + String summary() { + long p50 = percentile(durationsMs, 0.50); + long p95 = percentile(durationsMs, 0.95); + long filesMedian = percentile(numFiles, 0.50); + long bytesMedian = percentile(bytesRead, 0.50); + String planned = plannedFiles >= 0 ? String.valueOf(plannedFiles) : "n/a"; + String afterBloom = afterBloomFiles >= 0 ? 
String.valueOf(afterBloomFiles) : "n/a"; + return String.format( + Locale.ROOT, + "latency_ms(p50=%d,p95=%d) scanMetricFiles(p50=%d) plannedFiles=%s afterBloom=%s bytesMetric(p50=%d)", + p50, + p95, + filesMedian, + planned, + afterBloom, + bytesMedian); + } + } + + private record FileCounts(int plannedFiles, int afterBloomFiles) {} + + private static long percentile(List<Long> values, double percentile) { + if (values.isEmpty()) { + return 0L; + } + + List<Long> sorted = Lists.newArrayList(values); + sorted.sort(Comparator.naturalOrder()); + + int size = sorted.size(); + int idx = (int) Math.ceil(percentile * size) - 1; + idx = Math.max(0, Math.min(size - 1, idx)); + return sorted.get(idx); + } + + private record BenchmarkConfig( + int numFiles, int rowsPerFile, int warmupRuns, int measuredRuns, int numDays) { + + static BenchmarkConfig fromSystemProperties() { + return new BenchmarkConfig( + intProp("ICEBERG_BENCHMARK_NUM_FILES", "iceberg.benchmark.numFiles", 1000), + intProp("ICEBERG_BENCHMARK_ROWS_PER_FILE", "iceberg.benchmark.rowsPerFile", 2000), + intProp("ICEBERG_BENCHMARK_WARMUP_RUNS", "iceberg.benchmark.warmupRuns", 3), + intProp("ICEBERG_BENCHMARK_MEASURED_RUNS", "iceberg.benchmark.measuredRuns", 10), + intProp("ICEBERG_BENCHMARK_NUM_DAYS", "iceberg.benchmark.numDays", 10)); + } + + int shufflePartitions() { + // Use a stable upper bound to keep planning reasonable locally. + return Math.max(8, Math.min(2000, numFiles)); + } + + String needleDay() { + // Needle is injected into file_id=numFiles/2, which maps to day=(file_id % numDays). + int dayOffset = Math.floorMod(Math.max(0, numFiles / 2), Math.max(1, numDays)); + // 2026-01-01 + offset days + java.time.LocalDate base = java.time.LocalDate.of(2026, 1, 1); + return base.plusDays(dayOffset).toString(); + } + + private static int intProp(String envKey, String sysPropKey, int defaultValue) { + String envValue = System.getenv(envKey); + if (envValue != null && !envValue.isEmpty()) { + try { + return Integer.parseInt(envValue); + } catch (NumberFormatException ignored) { + // fall back to sys prop/default + } + } + + String sysPropValue = System.getProperty(sysPropKey); + if (sysPropValue == null || sysPropValue.isEmpty()) { + return defaultValue; + } + + try { + return Integer.parseInt(sysPropValue); + } catch (NumberFormatException e) { + return defaultValue; + } + } + } +}
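The dataset and workload dimensions are resolved by intProp above: an ICEBERG_BENCHMARK_* environment variable wins, then the iceberg.benchmark.* system property, then the default. A minimal sketch of overriding them, assuming it runs in the same JVM before fromSystemProperties() is read; the values shown are illustrative, not tuned recommendations:

    // Illustrative overrides (defaults: numFiles=1000, rowsPerFile=2000,
    // warmupRuns=3, measuredRuns=10, numDays=10). Environment variables such as
    // ICEBERG_BENCHMARK_NUM_FILES take precedence over these system properties when set.
    System.setProperty("iceberg.benchmark.numFiles", "200");
    System.setProperty("iceberg.benchmark.rowsPerFile", "1000");
    System.setProperty("iceberg.benchmark.warmupRuns", "1");
    System.setProperty("iceberg.benchmark.measuredRuns", "5");
    System.setProperty("iceberg.benchmark.numDays", "10");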

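End to end, the benchmark exercises two new entry points: the buildBloomFilterIndex Spark action, which writes the bloom-filter-v1 Puffin index for a column, and BloomFilterIndexUtil.pruneTasksWithBloomIndex, which drops planned file tasks whose filter rules out the lookup value. A minimal sketch mirroring the calls in buildPuffinBloomIndex and estimateFileCounts above, assuming an active SparkSession named spark; the table name and lookup value are placeholders, and checked exceptions (NoSuchTableException, ParseException, IOException from close()) are left to the caller:

    import java.util.List;
    import org.apache.iceberg.FileScanTask;
    import org.apache.iceberg.Table;
    import org.apache.iceberg.expressions.Expressions;
    import org.apache.iceberg.io.CloseableIterable;
    import org.apache.iceberg.relocated.com.google.common.collect.Lists;
    import org.apache.iceberg.spark.Spark3Util;
    import org.apache.iceberg.spark.actions.BloomFilterIndexUtil;
    import org.apache.iceberg.spark.actions.SparkActions;

    Table table = Spark3Util.loadIcebergTable(spark, "bench.default.t");

    // Build the Puffin bloom index for 'id'; refresh to pick up the new statistics file.
    SparkActions.get(spark).buildBloomFilterIndex(table).column("id").execute();
    table.refresh();

    // Plan with the regular metadata filters, then apply the bloom index as an extra
    // file-level prune; afterBloom.size() <= planned.size().
    List<FileScanTask> planned = Lists.newArrayList();
    try (CloseableIterable<FileScanTask> tasks =
        table.newScan().filter(Expressions.equal("id", "<lookup-value>")).planFiles()) {
      tasks.forEach(planned::add);
    }
    List<FileScanTask> afterBloom =
        BloomFilterIndexUtil.pruneTasksWithBloomIndex(
            table, table.currentSnapshot(), () -> planned, "id", "<lookup-value>");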