shangxinli commented on code in PR #18729: URL: https://github.com/apache/hudi/pull/18729#discussion_r3266785021
########## hudi-examples/hudi-examples-spark/src/test/python/vector_blob_demo/hudi_vector_search_batch_demo.py: ########## @@ -0,0 +1,713 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Hudi BATCH vector search demo — certifies `hudi_vector_search_batch` at non- +trivial scale against a numpy ground-truth oracle. + +Sibling to `hudi_sql_vector_blob_demo.py` (single-query). Same dataset, same +embedding model, same Hudi DDL — but the search step is the **table-to-table** +batch TVF described in RFC-102: + + SELECT * + FROM hudi_vector_search_batch( + 'pets_batch_corpus_<format>', 'embedding', + 'pets_batch_queries_<format>', 'embedding', + k, 'cosine') + +Flow: + 1. Load N_CORPUS + N_QUERIES Oxford-IIIT Pet images. + 2. Generate L2-normalized embeddings with `mobilenetv3_small_100`. + 3. Split into corpus (N_CORPUS) + held-out queries (N_QUERIES). + 4. Stage both via PyArrow, write each to its own Hudi table. + 5. Run `hudi_vector_search_batch` and collect results. + 6. **Oracle validation:** compute the cosine distance matrix in numpy from + the same embeddings and assert the TVF's top-k per query matches. + 7. Render a result panel (one row per query showing its top-k matches). + +Env vars: + HUDI_BUNDLE_JAR (defaults to ~/Downloads/hudi-spark3.5-bundle_2.12-1.2.0-rc1.jar) + HUDI_BASE_FILE_FORMAT (default 'lance'; set to 'parquet' to use Parquet) + LANCE_BUNDLE_JAR (defaults to ~/Downloads/lance-spark-bundle-3.5_2.12-0.4.0.jar; only used when HUDI_BASE_FILE_FORMAT=lance) + HUDI_BATCH_N_CORPUS (default 1000; rows in the corpus Hudi table) + HUDI_BATCH_N_QUERIES (default 20; rows in the query Hudi table) + HUDI_BATCH_TOP_K (default 5) + PYSPARK_DRIVER_MEMORY (default '4g') + HUDI_LANCE_DEMO_OUTDIR (default './outputs') +""" + +import io +import os +import shutil +import sys +from pathlib import Path + +# MUST run before any `pyspark` import — local-mode driver heap is fixed at +# JVM launch time and cannot be raised via SparkSession.config() later. +_driver_mem = os.getenv("PYSPARK_DRIVER_MEMORY", "4g") +os.environ.setdefault( + "PYSPARK_SUBMIT_ARGS", + f"--driver-memory {_driver_mem} --conf spark.driver.maxResultSize=2g pyspark-shell", +) + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +import torch +import timm +from sklearn.preprocessing import normalize +from PIL import Image + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt # noqa: E402 + +from torchvision.datasets import OxfordIIITPet # noqa: E402 + +from pyspark.sql import SparkSession + + +# ====================================================== +# CONFIGURATION +# ====================================================== + +_file_format = os.getenv("HUDI_BASE_FILE_FORMAT", "lance").lower() +if _file_format not in ("lance", "parquet"): + sys.exit(f"ERROR: HUDI_BASE_FILE_FORMAT must be 'lance' or 'parquet', got '{_file_format}'") + +CONFIG = { + "dataset": "OxfordIIITPet", + "base_file_format": _file_format, + "corpus_table_path": f"/tmp/hudi_batch_corpus_{_file_format}_pets", + "corpus_table_name": f"pets_batch_corpus_{_file_format}", + "queries_table_path": f"/tmp/hudi_batch_queries_{_file_format}_pets", + "queries_table_name": f"pets_batch_queries_{_file_format}", + "n_corpus": int(os.getenv("HUDI_BATCH_N_CORPUS", "1000")), + "n_queries": int(os.getenv("HUDI_BATCH_N_QUERIES", "20")), + "top_k": int(os.getenv("HUDI_BATCH_TOP_K", "5")), + "embedding_model": "mobilenetv3_small_100", + "output_dir": os.getenv("HUDI_LANCE_DEMO_OUTDIR", "./outputs"), + "panel_filename": f"hudi_vector_search_batch_{_file_format}_results.png", + "log_level": "ERROR", + "hide_progress": True, + # Oracle tolerance: cosine distance computed on L2-normalized float32 vectors + # in numpy vs JVM-side DenseVector(Double) UDF. Float32 → Float64 widening + + # different summation orders allow ~1e-5 deltas. + "oracle_distance_tol": 1e-5, +} + +BLOB_REFERENCE_CAST = ( + "struct<external_path:string,offset:bigint,length:bigint,managed:boolean>" +) + + +# ====================================================== +# UTILITIES +# ====================================================== + +def ensure_dir(p: Path) -> None: + p.mkdir(parents=True, exist_ok=True) + + +def wipe_prior_state() -> None: + """ + Remove this script's prior table dirs and staging Parquets so re-runs are + idempotent. `DROP TABLE IF EXISTS` (run inside `create_hudi_table_sql`) only + removes the catalog entry — the data dir and `.hoodie/` timeline at + LOCATION persist, so a re-run would query stale rows alongside fresh ones + and the oracle would (correctly) flag the mismatch. + """ + targets = [ + CONFIG["corpus_table_path"], + CONFIG["queries_table_path"], + f"/tmp/staging_pets_batch_corpus_{CONFIG['base_file_format']}.parquet", + f"/tmp/staging_pets_batch_queries_{CONFIG['base_file_format']}.parquet", + ] + for t in targets: + p = Path(t) + if p.is_dir(): + shutil.rmtree(p, ignore_errors=True) + elif p.is_file(): + p.unlink(missing_ok=True) + # Catalog warehouse from prior runs in this cwd. + shutil.rmtree("spark-warehouse", ignore_errors=True) + + +def save_png_bytes(img_bytes: bytes, path: Path) -> None: + ensure_dir(path.parent) + with open(path, "wb") as f: + f.write(img_bytes) + + +def default_hudi_bundle_jar() -> str: + return str(Path.home() / "Downloads" / "hudi-spark3.5-bundle_2.12-1.2.0-rc1.jar") + + +def default_lance_bundle_jar() -> str: + return str(Path.home() / "Downloads" / "lance-spark-bundle-3.5_2.12-0.4.0.jar") + + +def resolve_jars() -> str: + hudi_jar = os.getenv("HUDI_BUNDLE_JAR", default_hudi_bundle_jar()) + if not Path(hudi_jar).is_file(): + sys.exit( + f"ERROR: HUDI_BUNDLE_JAR does not exist at {hudi_jar}\n" + "Download the Apache 1.2.0-rc1 staging jar with:\n" + " curl -L -o ~/Downloads/hudi-spark3.5-bundle_2.12-1.2.0-rc1.jar \\\n" + " https://repository.apache.org/content/repositories/orgapachehudi-1176/org/apache/hudi/hudi-spark3.5-bundle_2.12/1.2.0-rc1/hudi-spark3.5-bundle_2.12-1.2.0-rc1.jar\n" + "or set HUDI_BUNDLE_JAR=/abs/path/to/locally-built.jar." + ) + if CONFIG["base_file_format"] != "lance": + return hudi_jar + lance_jar = os.getenv("LANCE_BUNDLE_JAR", default_lance_bundle_jar()) + if not Path(lance_jar).is_file(): + sys.exit( + f"ERROR: LANCE_BUNDLE_JAR does not exist at {lance_jar}\n" + "Download the Lance 0.4.0 bundle from Maven Central with:\n" + " curl -L -o ~/Downloads/lance-spark-bundle-3.5_2.12-0.4.0.jar \\\n" + " https://repo1.maven.org/maven2/com/lancedb/lance-spark-bundle-3.5_2.12/0.4.0/lance-spark-bundle-3.5_2.12-0.4.0.jar\n" + "or set LANCE_BUNDLE_JAR=/abs/path/to/jar." + ) + return f"{hudi_jar},{lance_jar}" + + +# ====================================================== +# 1. SPARK SESSION SETUP +# ====================================================== + +def create_spark() -> SparkSession: + jars = resolve_jars() + builder = ( + SparkSession.builder.appName("Hudi-Vector-Search-Batch-Demo") + .config("spark.jars", jars) + .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .config( + "spark.sql.extensions", + "org.apache.spark.sql.hudi.HoodieSparkSessionExtension", + ) + .config( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.hudi.catalog.HoodieCatalog", + ) + .config("spark.sql.session.timeZone", "UTC") + .config("hoodie.read.blob.inline.mode", "CONTENT") + .config("spark.default.parallelism", "2") + .config("spark.sql.shuffle.partitions", "2") + ) + if CONFIG.get("hide_progress", True): + builder = builder.config("spark.ui.showConsoleProgress", "false") + spark = builder.getOrCreate() + spark.sparkContext.setLogLevel(CONFIG.get("log_level", "ERROR")) + return spark + + +# ====================================================== +# 2. LOAD DATASET (Oxford-IIIT Pet) +# ====================================================== + +def load_dataset(n_samples): + print(f"Loading dataset: Oxford-IIIT Pet ({n_samples} samples)...") + root = os.path.expanduser("~/.cache/torchvision") + ds = OxfordIIITPet(root=root, split="trainval", download=True) + class_names = ds.classes + + rng = np.random.default_rng() + n = min(n_samples, len(ds)) + indices = rng.choice(len(ds), size=n, replace=False) + + data = [] + for idx in indices: + img, label = ds[int(idx)] + img = img.convert("RGB") + bio = io.BytesIO() + img.save(bio, format="PNG") + img_bytes = bio.getvalue() + w, h = img.size + category = class_names[label] if isinstance(class_names, list) else str(label) + safe_category = category.replace("/", "_") + data.append( + { + "image_id": f"pets_{int(idx):06d}", + "category": category, + "category_sanitized": safe_category, + "label": int(label), + "description": f"{category} from Oxford-IIIT Pet", + "image_bytes_raw": img_bytes, + "width": int(w), + "height": int(h), + } + ) + print(f"✓ Loaded {len(data)} images") + return data, class_names + + +# ====================================================== +# 3. EMBEDDING MODEL (timm) +# ====================================================== + +def create_embedding_model(): + print(f"Loading embedding model: {CONFIG['embedding_model']}...") + model = timm.create_model(CONFIG["embedding_model"], pretrained=True, num_classes=0) + model.eval() + data_config = timm.data.resolve_model_data_config(model) + transform = timm.data.create_transform(**data_config, is_training=False) + print("✓ Model loaded") + return model, transform + + +def generate_embeddings(data, model, transform): + print(f"Generating embeddings for {len(data)} images...") + images = [] + for item in data: + img = Image.open(io.BytesIO(item["image_bytes_raw"])).convert("RGB") + images.append(transform(img)) + batch = torch.stack(images) + with torch.no_grad(): + feats = model(batch).detach().cpu().numpy() + feats = normalize(feats) + for i, item in enumerate(data): + item["embedding"] = feats[i].tolist() + print(f"✓ Generated embeddings (dimension: {feats.shape[1]})") + return data, int(feats.shape[1]) + + +# ====================================================== +# 4. SPLIT corpus + queries +# ====================================================== + +def split_corpus_and_queries(data): + """ + First N_CORPUS rows → corpus, last N_QUERIES rows → queries. Order is + already randomized by `load_dataset`'s `rng.choice(replace=False)`, so a + simple slice gives two disjoint subsets. + + We hold both subsets in Python so the oracle (Step 7) can recompute the + distance matrix from the exact same embeddings that were written to Hudi. + """ + n_corpus = CONFIG["n_corpus"] + n_queries = CONFIG["n_queries"] + if len(data) < n_corpus + n_queries: + sys.exit( + f"ERROR: requested n_corpus={n_corpus} + n_queries={n_queries}={n_corpus + n_queries} " + f"rows but dataset returned only {len(data)}" + ) + corpus = data[:n_corpus] + queries = data[n_corpus : n_corpus + n_queries] + + # Sanity: image_ids must be disjoint (no row appears in both tables). + corpus_ids = {r["image_id"] for r in corpus} + query_ids = {r["image_id"] for r in queries} + assert corpus_ids.isdisjoint(query_ids), "corpus and queries overlap" + print(f"✓ Split: {len(corpus)} corpus rows, {len(queries)} query rows (disjoint)") + return corpus, queries + + +# ====================================================== +# 5. STAGE → PARQUET → TEMP VIEW (PyArrow, bypassing PythonRDD) +# ====================================================== + +def stage_to_parquet_with_pyarrow(data, embedding_dim: int, staging_path: str) -> None: + arrow_schema = pa.schema( + [ + pa.field("image_id", pa.string(), nullable=False), + pa.field("category", pa.string(), nullable=False), + pa.field("category_sanitized", pa.string(), nullable=False), + pa.field("label", pa.int32(), nullable=False), + pa.field("description", pa.string(), nullable=True), + pa.field("image_bytes_raw", pa.binary(), nullable=False), + pa.field("width", pa.int32(), nullable=False), + pa.field("height", pa.int32(), nullable=False), + pa.field( + "embedding", + pa.list_( + pa.field("element", pa.float32(), nullable=False), + list_size=embedding_dim, + ), + nullable=False, + ), + ] + ) + columns = { + "image_id": [d["image_id"] for d in data], + "category": [d["category"] for d in data], + "category_sanitized": [d["category_sanitized"] for d in data], + "label": [int(d["label"]) for d in data], + "description": [d.get("description") for d in data], + "image_bytes_raw": [d["image_bytes_raw"] for d in data], + "width": [int(d["width"]) for d in data], + "height": [int(d["height"]) for d in data], + "embedding": [d["embedding"] for d in data], + } + pq.write_table(pa.table(columns, schema=arrow_schema), staging_path) + + +def register_staging_view(spark, data, embedding_dim, view_name, staging_path): + print(f"Staging Python data → Parquet at {staging_path} (PyArrow, no Spark)...") + stage_to_parquet_with_pyarrow(data, embedding_dim, staging_path) + spark.read.parquet(staging_path).createOrReplaceTempView(view_name) + print(f"✓ Registered Spark temp view: {view_name}") + + +# ====================================================== +# 6. CREATE TABLE + INSERT — SQL (run twice: corpus + queries) +# ====================================================== + +def create_hudi_table_sql(spark, embedding_dim, table_name, table_path): + print(f"\nDDL: CREATE TABLE {table_name} ... [{CONFIG['base_file_format']} base files]") + spark.sql(f"DROP TABLE IF EXISTS {table_name}") + ddl = f""" + CREATE TABLE {table_name} ( + image_id STRING, + category STRING, + category_sanitized STRING, + label INT, + description STRING, + image_bytes BLOB COMMENT 'Pet image bytes (INLINE)', + width INT, + height INT, + embedding VECTOR({embedding_dim}) + COMMENT 'Image embedding for ANN search' + ) USING hudi + PARTITIONED BY (category_sanitized) + LOCATION '{table_path}' + TBLPROPERTIES ( + primaryKey = 'image_id', + preCombineField = 'image_id', + type = 'cow', + 'hoodie.table.base.file.format' = '{CONFIG['base_file_format']}', + 'hoodie.write.record.merge.custom.implementation.classes' = 'org.apache.hudi.DefaultSparkRecordMerger' + ) + """ + spark.sql(ddl) + print(f"✓ Created table {table_name} at {table_path}") + + +def insert_into_hudi_sql(spark, table_name, staging_view): + print(f"\nDML: INSERT INTO {table_name} SELECT ... FROM {staging_view}") + insert = f""" + INSERT INTO {table_name} + SELECT + image_id, + category, + category_sanitized, + label, + description, + named_struct( + 'type', 'INLINE', + 'data', image_bytes_raw, + 'reference', cast(null as {BLOB_REFERENCE_CAST}) + ) AS image_bytes, + width, + height, + embedding + FROM {staging_view} + """ + spark.sql(insert) + count = spark.sql(f"SELECT COUNT(image_id) AS c FROM {table_name}").collect()[0]["c"] + print(f"✓ Inserted {count} records into {table_name}") + + +# ====================================================== +# 7. BATCH SEARCH — `hudi_vector_search_batch` TVF +# ====================================================== + +def run_batch_search_sql(spark): + """ + Run the batch TVF. Both tables share column names ({image_id, category, + category_sanitized, label, description, image_bytes, width, height, + embedding}), so every non-embedding column on the query side gets the + `_hudi_query_` prefix per HoodieVectorSearchPlanBuilder.scala (the + clashing-column rename path is automatically exercised). + """ + k = CONFIG["top_k"] + corpus = CONFIG["corpus_table_name"] + queries = CONFIG["queries_table_name"] + print(f"\nDQL: hudi_vector_search_batch(corpus={corpus}, queries={queries}, k={k}, cosine)") + sql = f""" + SELECT + image_id AS corpus_image_id, + category AS corpus_category, + image_bytes AS corpus_image_bytes, + _hudi_query_image_id AS query_image_id, + _hudi_query_category AS query_category, + _hudi_distance, + _hudi_query_index + FROM hudi_vector_search_batch( + '{corpus}', + 'embedding', + '{queries}', + 'embedding', + {k}, + 'cosine' + ) + ORDER BY _hudi_query_index, _hudi_distance + """ + print(sql.strip()) + rows = spark.sql(sql).collect() + print(f"✓ TVF returned {len(rows)} rows ({CONFIG['n_queries']} queries × {k} matches)") + return rows + + +# ====================================================== +# 8. ORACLE VALIDATION (numpy) — the certification step +# ====================================================== + +def oracle_validate_with_numpy(corpus, queries, tvf_rows, k): + """ + Re-derive top-k from the exact same Python-side embeddings and confirm the + TVF agrees. This is the load-bearing assertion of the script — if it + passes, batch mode is correct on this dataset; if it fails, the TVF and + numpy disagree and the script exits non-zero. + + Cosine distance via L2-normalized dot product: `dist = 1 - sim`. Embeddings + are already L2-normalized by `generate_embeddings` (sklearn.preprocessing + .normalize), so no re-normalization is needed here. + """ + print("\n" + "-" * 80) + print("ORACLE: numpy ground-truth vs TVF result") + print("-" * 80) + + corpus_embs = np.asarray([r["embedding"] for r in corpus], dtype=np.float32) Review Comment: The oracle reads `corpus`/`queries` from the in-memory Python list — the same objects that were written to Hudi. If Hudi's write or read path silently truncates, reorders, or precision-clips a VECTOR column, this assertion won't catch it: what's actually being certified is that the TVF's math agrees with numpy, not that Hudi correctly stored and returned the embeddings. For a `CERTIFIED ✓` claim against RFC-102, the oracle should rebuild `corpus_embs` / `query_embs` from a `SELECT image_id, embedding FROM <table>` read-back, so the full round-trip is in the assertion path. The single-query sibling demo likely has the same gap — worth fixing in both. ########## hudi-examples/hudi-examples-spark/src/test/python/vector_blob_demo/hudi_vector_search_batch_demo.py: ########## @@ -0,0 +1,713 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Hudi BATCH vector search demo — certifies `hudi_vector_search_batch` at non- +trivial scale against a numpy ground-truth oracle. + +Sibling to `hudi_sql_vector_blob_demo.py` (single-query). Same dataset, same +embedding model, same Hudi DDL — but the search step is the **table-to-table** +batch TVF described in RFC-102: + + SELECT * + FROM hudi_vector_search_batch( + 'pets_batch_corpus_<format>', 'embedding', + 'pets_batch_queries_<format>', 'embedding', + k, 'cosine') + +Flow: + 1. Load N_CORPUS + N_QUERIES Oxford-IIIT Pet images. + 2. Generate L2-normalized embeddings with `mobilenetv3_small_100`. + 3. Split into corpus (N_CORPUS) + held-out queries (N_QUERIES). + 4. Stage both via PyArrow, write each to its own Hudi table. + 5. Run `hudi_vector_search_batch` and collect results. + 6. **Oracle validation:** compute the cosine distance matrix in numpy from + the same embeddings and assert the TVF's top-k per query matches. + 7. Render a result panel (one row per query showing its top-k matches). + +Env vars: + HUDI_BUNDLE_JAR (defaults to ~/Downloads/hudi-spark3.5-bundle_2.12-1.2.0-rc1.jar) + HUDI_BASE_FILE_FORMAT (default 'lance'; set to 'parquet' to use Parquet) + LANCE_BUNDLE_JAR (defaults to ~/Downloads/lance-spark-bundle-3.5_2.12-0.4.0.jar; only used when HUDI_BASE_FILE_FORMAT=lance) + HUDI_BATCH_N_CORPUS (default 1000; rows in the corpus Hudi table) + HUDI_BATCH_N_QUERIES (default 20; rows in the query Hudi table) + HUDI_BATCH_TOP_K (default 5) + PYSPARK_DRIVER_MEMORY (default '4g') + HUDI_LANCE_DEMO_OUTDIR (default './outputs') +""" + +import io +import os +import shutil +import sys +from pathlib import Path + +# MUST run before any `pyspark` import — local-mode driver heap is fixed at +# JVM launch time and cannot be raised via SparkSession.config() later. +_driver_mem = os.getenv("PYSPARK_DRIVER_MEMORY", "4g") +os.environ.setdefault( + "PYSPARK_SUBMIT_ARGS", + f"--driver-memory {_driver_mem} --conf spark.driver.maxResultSize=2g pyspark-shell", +) + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +import torch +import timm +from sklearn.preprocessing import normalize +from PIL import Image + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt # noqa: E402 + +from torchvision.datasets import OxfordIIITPet # noqa: E402 + +from pyspark.sql import SparkSession + + +# ====================================================== +# CONFIGURATION +# ====================================================== + +_file_format = os.getenv("HUDI_BASE_FILE_FORMAT", "lance").lower() +if _file_format not in ("lance", "parquet"): + sys.exit(f"ERROR: HUDI_BASE_FILE_FORMAT must be 'lance' or 'parquet', got '{_file_format}'") + +CONFIG = { + "dataset": "OxfordIIITPet", + "base_file_format": _file_format, + "corpus_table_path": f"/tmp/hudi_batch_corpus_{_file_format}_pets", Review Comment: Hardcoded `/tmp/hudi_batch_corpus_{format}_pets` (and the matching staging Parquet paths below) will collide whenever two runs of this demo happen concurrently — parallel CI agents, or a user retrying while a previous run is still mid-flight. `wipe_prior_state()` will then `shutil.rmtree` the other run's data underneath it. Since the PR wires this into `run_demos.sh` (i.e. it will eventually run unattended), recommend either `tempfile.mkdtemp(prefix="hudi_batch_corpus_")` or a PID/timestamp suffix, with `HUDI_LANCE_DEMO_OUTDIR`-style env vars allowed to override. ########## hudi-examples/hudi-examples-spark/src/test/python/vector_blob_demo/hudi_vector_search_batch_demo.py: ########## @@ -0,0 +1,713 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Hudi BATCH vector search demo — certifies `hudi_vector_search_batch` at non- +trivial scale against a numpy ground-truth oracle. + +Sibling to `hudi_sql_vector_blob_demo.py` (single-query). Same dataset, same +embedding model, same Hudi DDL — but the search step is the **table-to-table** +batch TVF described in RFC-102: + + SELECT * + FROM hudi_vector_search_batch( + 'pets_batch_corpus_<format>', 'embedding', + 'pets_batch_queries_<format>', 'embedding', + k, 'cosine') + +Flow: + 1. Load N_CORPUS + N_QUERIES Oxford-IIIT Pet images. + 2. Generate L2-normalized embeddings with `mobilenetv3_small_100`. + 3. Split into corpus (N_CORPUS) + held-out queries (N_QUERIES). + 4. Stage both via PyArrow, write each to its own Hudi table. + 5. Run `hudi_vector_search_batch` and collect results. + 6. **Oracle validation:** compute the cosine distance matrix in numpy from + the same embeddings and assert the TVF's top-k per query matches. + 7. Render a result panel (one row per query showing its top-k matches). + +Env vars: + HUDI_BUNDLE_JAR (defaults to ~/Downloads/hudi-spark3.5-bundle_2.12-1.2.0-rc1.jar) + HUDI_BASE_FILE_FORMAT (default 'lance'; set to 'parquet' to use Parquet) + LANCE_BUNDLE_JAR (defaults to ~/Downloads/lance-spark-bundle-3.5_2.12-0.4.0.jar; only used when HUDI_BASE_FILE_FORMAT=lance) + HUDI_BATCH_N_CORPUS (default 1000; rows in the corpus Hudi table) + HUDI_BATCH_N_QUERIES (default 20; rows in the query Hudi table) + HUDI_BATCH_TOP_K (default 5) + PYSPARK_DRIVER_MEMORY (default '4g') + HUDI_LANCE_DEMO_OUTDIR (default './outputs') +""" + +import io +import os +import shutil +import sys +from pathlib import Path + +# MUST run before any `pyspark` import — local-mode driver heap is fixed at +# JVM launch time and cannot be raised via SparkSession.config() later. +_driver_mem = os.getenv("PYSPARK_DRIVER_MEMORY", "4g") +os.environ.setdefault( + "PYSPARK_SUBMIT_ARGS", + f"--driver-memory {_driver_mem} --conf spark.driver.maxResultSize=2g pyspark-shell", +) + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +import torch +import timm +from sklearn.preprocessing import normalize +from PIL import Image + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt # noqa: E402 + +from torchvision.datasets import OxfordIIITPet # noqa: E402 + +from pyspark.sql import SparkSession + + +# ====================================================== +# CONFIGURATION +# ====================================================== + +_file_format = os.getenv("HUDI_BASE_FILE_FORMAT", "lance").lower() +if _file_format not in ("lance", "parquet"): + sys.exit(f"ERROR: HUDI_BASE_FILE_FORMAT must be 'lance' or 'parquet', got '{_file_format}'") + +CONFIG = { + "dataset": "OxfordIIITPet", + "base_file_format": _file_format, + "corpus_table_path": f"/tmp/hudi_batch_corpus_{_file_format}_pets", + "corpus_table_name": f"pets_batch_corpus_{_file_format}", + "queries_table_path": f"/tmp/hudi_batch_queries_{_file_format}_pets", + "queries_table_name": f"pets_batch_queries_{_file_format}", + "n_corpus": int(os.getenv("HUDI_BATCH_N_CORPUS", "1000")), + "n_queries": int(os.getenv("HUDI_BATCH_N_QUERIES", "20")), + "top_k": int(os.getenv("HUDI_BATCH_TOP_K", "5")), + "embedding_model": "mobilenetv3_small_100", + "output_dir": os.getenv("HUDI_LANCE_DEMO_OUTDIR", "./outputs"), + "panel_filename": f"hudi_vector_search_batch_{_file_format}_results.png", + "log_level": "ERROR", + "hide_progress": True, + # Oracle tolerance: cosine distance computed on L2-normalized float32 vectors + # in numpy vs JVM-side DenseVector(Double) UDF. Float32 → Float64 widening + + # different summation orders allow ~1e-5 deltas. + "oracle_distance_tol": 1e-5, +} + +BLOB_REFERENCE_CAST = ( + "struct<external_path:string,offset:bigint,length:bigint,managed:boolean>" +) + + +# ====================================================== +# UTILITIES +# ====================================================== + +def ensure_dir(p: Path) -> None: + p.mkdir(parents=True, exist_ok=True) + + +def wipe_prior_state() -> None: + """ + Remove this script's prior table dirs and staging Parquets so re-runs are + idempotent. `DROP TABLE IF EXISTS` (run inside `create_hudi_table_sql`) only + removes the catalog entry — the data dir and `.hoodie/` timeline at + LOCATION persist, so a re-run would query stale rows alongside fresh ones + and the oracle would (correctly) flag the mismatch. + """ + targets = [ + CONFIG["corpus_table_path"], + CONFIG["queries_table_path"], + f"/tmp/staging_pets_batch_corpus_{CONFIG['base_file_format']}.parquet", + f"/tmp/staging_pets_batch_queries_{CONFIG['base_file_format']}.parquet", + ] + for t in targets: + p = Path(t) + if p.is_dir(): + shutil.rmtree(p, ignore_errors=True) + elif p.is_file(): + p.unlink(missing_ok=True) + # Catalog warehouse from prior runs in this cwd. + shutil.rmtree("spark-warehouse", ignore_errors=True) + + +def save_png_bytes(img_bytes: bytes, path: Path) -> None: + ensure_dir(path.parent) + with open(path, "wb") as f: + f.write(img_bytes) + + +def default_hudi_bundle_jar() -> str: + return str(Path.home() / "Downloads" / "hudi-spark3.5-bundle_2.12-1.2.0-rc1.jar") + + +def default_lance_bundle_jar() -> str: + return str(Path.home() / "Downloads" / "lance-spark-bundle-3.5_2.12-0.4.0.jar") + + +def resolve_jars() -> str: + hudi_jar = os.getenv("HUDI_BUNDLE_JAR", default_hudi_bundle_jar()) Review Comment: Defaulting to `~/Downloads/hudi-spark3.5-bundle_2.12-1.2.0-rc1.jar` and `~/Downloads/lance-spark-bundle-3.5_2.12-0.4.0.jar` works for the author but anyone trying to actually reproduce the certification will hit a `does not exist` error and bail. The PR description bills this as a "reproducible end-to-end certification" — for that to hold, JAR resolution should either: 1. Resolve from the in-tree Maven build output (e.g. `hudi-spark3.5-bundle/target/hudi-spark3.5-bundle_2.12-*.jar` via `find`), or 2. `mvn dependency:get` on the first run with a checksum check. Bonus: pinning to `1.2.0-rc1` will rot once rc1 → GA, so the version should also be parameterized. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
