This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 6b4dc555a24 Add spannerio vector writer. (#36654)
6b4dc555a24 is described below
commit 6b4dc555a245b4f82250677de21890955867314c
Author: claudevdm <[email protected]>
AuthorDate: Thu Oct 30 16:41:53 2025 -0400
Add spannerio vector writer. (#36654)
* draft
* Simplify
* Add xlang test markers.
* Fix tests.
* lints
* Remove extractor_fn args.
* Linter.
---------
Co-authored-by: Claude <[email protected]>
---
.../beam_PostCommit_Python_Xlang_Gcp_Direct.json | 2 +-
.../python/apache_beam/ml/rag/ingestion/spanner.py | 646 +++++++++++++++++++++
.../ml/rag/ingestion/spanner_it_test.py | 601 +++++++++++++++++++
3 files changed, 1248 insertions(+), 1 deletion(-)
diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json
index 2504db607e4..95fef3e26ca 100644
--- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json
+++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to
run",
- "modification": 12
+ "modification": 13
}
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/spanner.py b/sdks/python/apache_beam/ml/rag/ingestion/spanner.py
new file mode 100644
index 00000000000..f79db470bca
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/ingestion/spanner.py
@@ -0,0 +1,646 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Cloud Spanner vector store writer for RAG pipelines.
+
+This module provides a writer for storing embeddings and associated metadata
+in Google Cloud Spanner. It supports flexible schema configuration with the
+ability to flatten metadata fields into dedicated columns.
+
+Example usage:
+
+ Default schema (id, embedding, content, metadata):
+ >>> config = SpannerVectorWriterConfig(
+ ... project_id="my-project",
+ ... instance_id="my-instance",
+ ... database_id="my-db",
+ ... table_name="embeddings"
+ ... )
+
+ Flattened metadata fields:
+ >>> specs = (
+ ... SpannerColumnSpecsBuilder()
+ ... .with_id_spec()
+ ... .with_embedding_spec()
+ ... .with_content_spec()
+ ... .add_metadata_field("source", str)
+ ... .add_metadata_field("page_number", int, default=0)
+ ... .with_metadata_spec()
+ ... .build()
+ ... )
+ >>> config = SpannerVectorWriterConfig(
+ ... project_id="my-project",
+ ... instance_id="my-instance",
+ ... database_id="my-db",
+ ... table_name="embeddings",
+ ... column_specs=specs
+ ... )
+
+Spanner schema example:
+
+ CREATE TABLE embeddings (
+ id STRING(1024) NOT NULL,
+ embedding ARRAY<FLOAT32>(vector_length=>768),
+ content STRING(MAX),
+ source STRING(MAX),
+ page_number INT64,
+ metadata JSON
+ ) PRIMARY KEY (id)
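+
+Pipeline usage (a sketch, mirroring the integration tests; assumes ``chunks``
+is an in-memory list of Chunk objects with dense embeddings already computed):
+
+    >>> with beam.Pipeline() as p:
+    ...     _ = (
+    ...         p
+    ...         | beam.Create(chunks)
+    ...         | config.create_write_transform())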
+"""
+
+import functools
+import json
+from dataclasses import dataclass
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Literal
+from typing import NamedTuple
+from typing import Optional
+from typing import Type
+
+import apache_beam as beam
+from apache_beam.coders import registry
+from apache_beam.coders.row_coder import RowCoder
+from apache_beam.io.gcp import spanner
+from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
+from apache_beam.ml.rag.types import Chunk
+
+
+@dataclass
+class SpannerColumnSpec:
+ """Column specification for Spanner vector writes.
+
+ Defines how to extract and format values from Chunks for insertion into
+ Spanner table columns. Each spec maps to one column in the target table.
+
+ Attributes:
+ column_name: Name of the Spanner table column
+ python_type: Python type for the NamedTuple field (required for RowCoder)
+ value_fn: Function to extract value from a Chunk
+
+ Examples:
+ String column:
+ >>> SpannerColumnSpec(
+ ... column_name="id",
+ ... python_type=str,
+ ... value_fn=lambda chunk: chunk.id
+ ... )
+
+ Array column with conversion:
+ >>> SpannerColumnSpec(
+ ... column_name="embedding",
+ ... python_type=List[float],
+ ... value_fn=lambda chunk: chunk.embedding.dense_embedding
+ ... )
+ """
+ column_name: str
+ python_type: Type
+ value_fn: Callable[[Chunk], Any]
+
+
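+# Module-level helper (bound via functools.partial) rather than a closure, so
+# the extract/convert pair stays picklable when Beam serializes the transform.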
+def _extract_and_convert(extract_fn, convert_fn, chunk):
+ if convert_fn:
+ return convert_fn(extract_fn(chunk))
+ return extract_fn(chunk)
+
+
+class SpannerColumnSpecsBuilder:
+ """Builder for creating Spanner column specifications.
+
+ Provides a fluent API for defining table schemas and how to populate them
+ from Chunk objects. Supports standard Chunk fields (id, embedding, content,
+ metadata) and flattening metadata fields into dedicated columns.
+
+ Example:
+ >>> specs = (
+ ... SpannerColumnSpecsBuilder()
+ ... .with_id_spec()
+ ... .with_embedding_spec()
+ ... .with_content_spec()
+ ... .add_metadata_field("source", str)
+ ... .with_metadata_spec()
+ ... .build()
+ ... )
+ """
+ def __init__(self):
+ self._specs: List[SpannerColumnSpec] = []
+
+ @staticmethod
+ def with_defaults() -> 'SpannerColumnSpecsBuilder':
+ """Create builder with default schema.
+
+ Default schema includes:
+ - id (STRING): Chunk ID
+ - embedding (ARRAY<FLOAT32>): Dense embedding vector
+ - content (STRING): Chunk content text
+ - metadata (JSON): Full metadata as JSON
+
+ Returns:
+ Builder with default column specifications
+ """
+ return (
+ SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
+ with_content_spec().with_metadata_spec())
+
+ def with_id_spec(
+ self,
+ column_name: str = "id",
+ python_type: Type = str,
+ convert_fn: Optional[Callable[[str], Any]] = None
+ ) -> 'SpannerColumnSpecsBuilder':
+ """Add ID column specification.
+
+ Args:
+ column_name: Column name (default: "id")
+ python_type: Python type (default: str)
+ convert_fn: Optional converter (e.g., to cast to int)
+
+ Returns:
+ Self for method chaining
+
+ Examples:
+ Default string ID:
+ >>> builder.with_id_spec()
+
+ Integer ID with conversion:
+ >>> builder.with_id_spec(
+ ... python_type=int,
+ ... convert_fn=lambda id: int(id.split('_')[1])
+ ... )
+ """
+
+ self._specs.append(
+ SpannerColumnSpec(
+ column_name=column_name,
+ python_type=python_type,
+ value_fn=functools.partial(
+ _extract_and_convert, lambda chunk: chunk.id, convert_fn)))
+ return self
+
+ def with_embedding_spec(
+ self,
+ column_name: str = "embedding",
+ convert_fn: Optional[Callable[[List[float]], List[float]]] = None
+ ) -> 'SpannerColumnSpecsBuilder':
+ """Add embedding array column (ARRAY<FLOAT32> or ARRAY<FLOAT64>).
+
+ Args:
+ column_name: Column name (default: "embedding")
+ convert_fn: Optional converter (e.g., normalize, quantize)
+
+ Returns:
+ Self for method chaining
+
+ Examples:
+ Default embedding:
+ >>> builder.with_embedding_spec()
+
+ Normalized embedding:
+ >>> def normalize(vec):
+ ... norm = (sum(x**2 for x in vec) ** 0.5) or 1.0
+ ... return [x/norm for x in vec]
+ >>> builder.with_embedding_spec(convert_fn=normalize)
+
+ Rounded precision:
+ >>> builder.with_embedding_spec(
+ ... convert_fn=lambda vec: [round(x, 4) for x in vec]
+ ... )
+ """
+ def extract_fn(chunk: Chunk) -> List[float]:
+ if chunk.embedding is None or chunk.embedding.dense_embedding is None:
+ raise ValueError(f'Chunk must contain embedding: {chunk}')
+ return chunk.embedding.dense_embedding
+
+ self._specs.append(
+ SpannerColumnSpec(
+ column_name=column_name,
+ python_type=List[float],
+ value_fn=functools.partial(
+ _extract_and_convert, extract_fn, convert_fn)))
+ return self
+
+ def with_content_spec(
+ self,
+ column_name: str = "content",
+ python_type: Type = str,
+ convert_fn: Optional[Callable[[str], Any]] = None
+ ) -> 'SpannerColumnSpecsBuilder':
+ """Add content column.
+
+ Args:
+ column_name: Column name (default: "content")
+ python_type: Python type (default: str)
+ convert_fn: Optional converter
+
+ Returns:
+ Self for method chaining
+
+ Examples:
+ Default text content:
+ >>> builder.with_content_spec()
+
+ Content length as integer:
+ >>> builder.with_content_spec(
+ ... column_name="content_length",
+ ... python_type=int,
+ ... convert_fn=lambda text: len(text.split())
+ ... )
+
+ Truncated content:
+ >>> builder.with_content_spec(
+ ... convert_fn=lambda text: text[:1000]
+ ... )
+ """
+ def extract_fn(chunk: Chunk) -> str:
+ if chunk.content.text is None:
+ raise ValueError(f'Chunk must contain content: {chunk}')
+ return chunk.content.text
+
+ self._specs.append(
+ SpannerColumnSpec(
+ column_name=column_name,
+ python_type=python_type,
+ value_fn=functools.partial(
+ _extract_and_convert, extract_fn, convert_fn)))
+ return self
+
+ def with_metadata_spec(
+ self, column_name: str = "metadata") -> 'SpannerColumnSpecsBuilder':
+ """Add metadata JSON column.
+
+ Stores the full metadata dictionary as a JSON string in Spanner.
+
+ Args:
+ column_name: Column name (default: "metadata")
+
+ Returns:
+ Self for method chaining
+
+ Note:
+      Metadata is automatically converted to a JSON string using json.dumps()
+ """
+ value_fn = lambda chunk: json.dumps(chunk.metadata)
+ self._specs.append(
+ SpannerColumnSpec(
+ column_name=column_name, python_type=str, value_fn=value_fn))
+ return self
+
+ def add_metadata_field(
+ self,
+ field: str,
+ python_type: Type,
+ column_name: Optional[str] = None,
+ convert_fn: Optional[Callable[[Any], Any]] = None,
+ default: Any = None) -> 'SpannerColumnSpecsBuilder':
+ """Flatten a metadata field into its own column.
+
+ Extracts a specific field from chunk.metadata and stores it in a
+ dedicated table column.
+
+ Args:
+ field: Key in chunk.metadata to extract
+ python_type: Python type (must be explicitly specified)
+ column_name: Column name (default: same as field)
+ convert_fn: Optional converter for type casting/transformation
+ default: Default value if field is missing from metadata
+
+ Returns:
+ Self for method chaining
+
+ Examples:
+ String field:
+ >>> builder.add_metadata_field("source", str)
+
+ Integer with default:
+ >>> builder.add_metadata_field(
+ ... "page_number",
+ ... int,
+ ... default=0
+ ... )
+
+ Float with conversion:
+ >>> builder.add_metadata_field(
+ ... "confidence",
+ ... float,
+ ... convert_fn=lambda x: round(float(x), 2),
+ ... default=0.0
+ ... )
+
+ List of strings:
+ >>> builder.add_metadata_field(
+ ... "tags",
+ ... List[str],
+ ... default=[]
+ ... )
+
+ Timestamp with conversion:
+ >>> builder.add_metadata_field(
+ ... "created_at",
+ ... str,
+ ... convert_fn=lambda ts: ts.isoformat()
+ ... )
+ """
+ name = column_name or field
+
+ def value_fn(chunk: Chunk) -> Any:
+ return chunk.metadata.get(field, default)
+
+ self._specs.append(
+ SpannerColumnSpec(
+ column_name=name,
+ python_type=python_type,
+ value_fn=functools.partial(
+ _extract_and_convert, value_fn, convert_fn)))
+ return self
+
+ def add_column(
+ self,
+ column_name: str,
+ python_type: Type,
+ value_fn: Callable[[Chunk], Any]) -> 'SpannerColumnSpecsBuilder':
+ """Add a custom column with full control.
+
+ Args:
+ column_name: Column name
+ python_type: Python type (required)
+ value_fn: Value extraction function
+
+ Returns:
+ Self for method chaining
+
+ Examples:
+ Boolean flag:
+ >>> builder.add_column(
+ ... column_name="has_code",
+ ... python_type=bool,
+ ... value_fn=lambda chunk: "```" in chunk.content.text
+ ... )
+
+ Computed value:
+ >>> builder.add_column(
+ ... column_name="word_count",
+ ... python_type=int,
+ ... value_fn=lambda chunk: len(chunk.content.text.split())
+ ... )
+ """
+ self._specs.append(
+ SpannerColumnSpec(
+ column_name=column_name, python_type=python_type,
+ value_fn=value_fn))
+ return self
+
+ def build(self) -> List[SpannerColumnSpec]:
+ """Build the final list of column specifications.
+
+ Returns:
+ List of SpannerColumnSpec objects
+ """
+ return self._specs.copy()
+
+
+class _SpannerSchemaBuilder:
+ """Internal: Builds NamedTuple schema and registers RowCoder.
+
+ Creates a NamedTuple type from column specifications and registers it
+ with Beam's RowCoder for serialization.
+ """
+ def __init__(self, table_name: str, column_specs: List[SpannerColumnSpec]):
+ """Initialize schema builder.
+
+ Args:
+ table_name: Table name (used in NamedTuple type name)
+ column_specs: List of column specifications
+
+ Raises:
+ ValueError: If duplicate column names are found
+ """
+ self.table_name = table_name
+ self.column_specs = column_specs
+
+ # Validate no duplicates
+ names = [col.column_name for col in column_specs]
+ duplicates = set(name for name in names if names.count(name) > 1)
+ if duplicates:
+ raise ValueError(f"Duplicate column names: {duplicates}")
+
+ # Create NamedTuple type
+ fields = [(col.column_name, col.python_type) for col in column_specs]
+ type_name = f"SpannerVectorRecord_{table_name}"
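+    # For the default specs and table "embeddings" this is equivalent to a
+    # NamedTuple with fields:
+    #   id: str, embedding: List[float], content: str, metadata: str (JSON)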
+ self.record_type = NamedTuple(type_name, fields) # type: ignore
+
+ # Register coder
+ registry.register_coder(self.record_type, RowCoder)
+
+ def create_converter(self) -> Callable[[Chunk], NamedTuple]:
+ """Create converter function from Chunk to NamedTuple record.
+
+ Returns:
+ Function that converts a Chunk to a NamedTuple record
+ """
+ def convert(chunk: Chunk) -> self.record_type: # type: ignore
+ values = {
+ col.column_name: col.value_fn(chunk)
+ for col in self.column_specs
+ }
+ return self.record_type(**values) # type: ignore
+
+ return convert
+
+
+class SpannerVectorWriterConfig(VectorDatabaseWriteConfig):
+ """Configuration for writing vectors to Cloud Spanner.
+
+ Supports flexible schema configuration through column specifications and
+ provides control over Spanner-specific write parameters.
+
+ Examples:
+ Default schema:
+ >>> config = SpannerVectorWriterConfig(
+ ... project_id="my-project",
+ ... instance_id="my-instance",
+ ... database_id="my-db",
+ ... table_name="embeddings"
+ ... )
+
+ Custom schema with flattened metadata:
+ >>> specs = (
+ ... SpannerColumnSpecsBuilder()
+ ... .with_id_spec()
+ ... .with_embedding_spec()
+ ... .with_content_spec()
+ ... .add_metadata_field("source", str)
+ ... .add_metadata_field("page_number", int, default=0)
+ ... .with_metadata_spec()
+ ... .build()
+ ... )
+ >>> config = SpannerVectorWriterConfig(
+ ... project_id="my-project",
+ ... instance_id="my-instance",
+ ... database_id="my-db",
+ ... table_name="embeddings",
+ ... column_specs=specs
+ ... )
+
+ With emulator:
+ >>> config = SpannerVectorWriterConfig(
+ ... project_id="test-project",
+ ... instance_id="test-instance",
+ ... database_id="test-db",
+ ... table_name="embeddings",
+ ... emulator_host="http://localhost:9010"
+ ... )
+ """
+ def __init__(
+ self,
+ project_id: str,
+ instance_id: str,
+ database_id: str,
+ table_name: str,
+ *,
+ # Schema configuration
+ column_specs: Optional[List[SpannerColumnSpec]] = None,
+ # Write operation type
+ write_mode: Literal["INSERT", "UPDATE", "REPLACE",
+ "INSERT_OR_UPDATE"] = "INSERT_OR_UPDATE",
+ # Batching configuration
+ max_batch_size_bytes: Optional[int] = None,
+ max_number_mutations: Optional[int] = None,
+ max_number_rows: Optional[int] = None,
+ grouping_factor: Optional[int] = None,
+ # Networking
+ host: Optional[str] = None,
+ emulator_host: Optional[str] = None,
+ expansion_service: Optional[str] = None,
+ # Retry/deadline configuration
+ commit_deadline: Optional[int] = None,
+ max_cumulative_backoff: Optional[int] = None,
+ # Error handling
+ failure_mode: Optional[
+ spanner.FailureMode] = spanner.FailureMode.REPORT_FAILURES,
+ high_priority: bool = False,
+ # Additional Spanner arguments
+ **spanner_kwargs):
+ """Initialize Spanner vector writer configuration.
+
+ Args:
+ project_id: GCP project ID
+ instance_id: Spanner instance ID
+ database_id: Spanner database ID
+ table_name: Target table name
+ column_specs: Schema configuration using SpannerColumnSpecsBuilder.
+ If None, uses default schema (id, embedding, content, metadata)
+ write_mode: Spanner write operation type:
+ - INSERT: Fail if row exists
+ - UPDATE: Fail if row doesn't exist
+ - REPLACE: Delete then insert
+ - INSERT_OR_UPDATE: Insert or update if exists (default)
+ max_batch_size_bytes: Maximum bytes per mutation batch (default: 1MB)
+ max_number_mutations: Maximum cell mutations per batch (default: 5000)
+ max_number_rows: Maximum rows per batch (default: 500)
+      grouping_factor: Multiple of max mutations used to select and sort
+        rows before batching (default: 1000)
+ host: Spanner host URL (usually not needed)
+ emulator_host: Spanner emulator host (e.g., "http://localhost:9010")
+ expansion_service: Java expansion service address (host:port)
+ commit_deadline: Commit API deadline in seconds (default: 15)
+ max_cumulative_backoff: Max retry backoff seconds (default: 900)
+ failure_mode: Error handling strategy:
+ - FAIL_FAST: Throw exception for any failure
+ - REPORT_FAILURES: Continue processing (default)
+ high_priority: Use high priority for operations (default: False)
+ **spanner_kwargs: Additional keyword arguments to pass to the
+ underlying Spanner write transform. Use this to pass any
+ Spanner-specific parameters not explicitly exposed by this config.
+ """
+ self.project_id = project_id
+ self.instance_id = instance_id
+ self.database_id = database_id
+ self.table_name = table_name
+ self.write_mode = write_mode
+ self.max_batch_size_bytes = max_batch_size_bytes
+ self.max_number_mutations = max_number_mutations
+ self.max_number_rows = max_number_rows
+ self.grouping_factor = grouping_factor
+ self.host = host
+ self.emulator_host = emulator_host
+ self.expansion_service = expansion_service
+ self.commit_deadline = commit_deadline
+ self.max_cumulative_backoff = max_cumulative_backoff
+ self.failure_mode = failure_mode
+ self.high_priority = high_priority
+ self.spanner_kwargs = spanner_kwargs
+
+ # Use defaults if not provided
+ specs = column_specs or SpannerColumnSpecsBuilder.with_defaults().build()
+
+ # Create schema builder (NamedTuple + RowCoder registration)
+ self.schema_builder = _SpannerSchemaBuilder(table_name, specs)
+
+ def create_write_transform(self) -> beam.PTransform:
+ """Create the Spanner write PTransform.
+
+ Returns:
+ PTransform for writing to Spanner
+ """
+ return _WriteToSpannerVectorDatabase(self)
+
+
+class _WriteToSpannerVectorDatabase(beam.PTransform):
+ """Internal: PTransform for writing to Spanner vector database."""
+ def __init__(self, config: SpannerVectorWriterConfig):
+ """Initialize write transform.
+
+ Args:
+ config: Spanner writer configuration
+ """
+ self.config = config
+ self.schema_builder = config.schema_builder
+
+ def expand(self, pcoll: beam.PCollection[Chunk]):
+    """Expand the transform.
+
+    Args:
+      pcoll: PCollection of Chunks to write
+
+    Returns:
+      Result of the underlying Spanner write transform
+    """
+ # Select appropriate Spanner write transform based on write_mode
+ write_transform_class = {
+ "INSERT": spanner.SpannerInsert,
+ "UPDATE": spanner.SpannerUpdate,
+ "REPLACE": spanner.SpannerReplace,
+ "INSERT_OR_UPDATE": spanner.SpannerInsertOrUpdate,
+ }[self.config.write_mode]
+
+ return (
+ pcoll
+ | "Convert to Records" >> beam.Map(
+ self.schema_builder.create_converter()).with_output_types(
+ self.schema_builder.record_type)
+ | "Write to Spanner" >> write_transform_class(
+ project_id=self.config.project_id,
+ instance_id=self.config.instance_id,
+ database_id=self.config.database_id,
+ table=self.config.table_name,
+ max_batch_size_bytes=self.config.max_batch_size_bytes,
+ max_number_mutations=self.config.max_number_mutations,
+ max_number_rows=self.config.max_number_rows,
+ grouping_factor=self.config.grouping_factor,
+ host=self.config.host,
+ emulator_host=self.config.emulator_host,
+ commit_deadline=self.config.commit_deadline,
+ max_cumulative_backoff=self.config.max_cumulative_backoff,
+ failure_mode=self.config.failure_mode,
+ expansion_service=self.config.expansion_service,
+ high_priority=self.config.high_priority,
+ **self.config.spanner_kwargs))
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py
new file mode 100644
index 00000000000..ab9a982a81f
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py
@@ -0,0 +1,601 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Integration tests for Spanner vector store writer."""
+
+import logging
+import os
+import time
+import unittest
+import uuid
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.ml.rag.ingestion.spanner import SpannerVectorWriterConfig
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.types import Content
+from apache_beam.ml.rag.types import Embedding
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# pylint: disable=wrong-import-order, wrong-import-position
+try:
+ from google.cloud import spanner
+except ImportError:
+ spanner = None
+
+try:
+ from testcontainers.core.container import DockerContainer
+except ImportError:
+ DockerContainer = None
+# pylint: enable=wrong-import-order, wrong-import-position
+
+
+def retry(fn, retries, err_msg, *args, **kwargs):
+  """Retry a function, sleeping one second between attempts."""
+ for _ in range(retries):
+ try:
+ return fn(*args, **kwargs)
+ except: # pylint: disable=bare-except
+ time.sleep(1)
+ logging.error(err_msg)
+ raise RuntimeError(err_msg)
+
+
+class SpannerEmulatorHelper:
+ """Helper for managing Spanner emulator lifecycle."""
+ def __init__(self, project_id: str, instance_id: str, table_name: str):
+ self.project_id = project_id
+ self.instance_id = instance_id
+ self.table_name = table_name
+ self.host = None
+
+ # Start emulator
+ self.emulator = DockerContainer(
+ 'gcr.io/cloud-spanner-emulator/emulator:latest').with_exposed_ports(
+ 9010, 9020)
+    retry(self.emulator.start, 3, 'Could not start Spanner emulator.')
+ time.sleep(3)
+
+ self.host = f'{self.emulator.get_container_host_ip()}:' \
+ f'{self.emulator.get_exposed_port(9010)}'
+ os.environ['SPANNER_EMULATOR_HOST'] = self.host
+
+ # Create client and instance
+ self.client = spanner.Client(project_id)
+ self.instance = self.client.instance(instance_id)
+ self.create_instance()
+
+ def create_instance(self):
+ """Create Spanner instance in emulator."""
+ self.instance.create().result(120)
+
+ def create_database(self, database_id: str):
+ """Create database with default vector table schema."""
+ database = self.instance.database(
+ database_id,
+ ddl_statements=[
+ f'''
+ CREATE TABLE {self.table_name} (
+ id STRING(1024) NOT NULL,
+ embedding ARRAY<FLOAT32>(vector_length=>3),
+ content STRING(MAX),
+ metadata JSON
+ ) PRIMARY KEY (id)'''
+ ])
+ database.create().result(120)
+
+ def read_data(self, database_id: str):
+ """Read all data from the table."""
+ database = self.instance.database(database_id)
+ with database.snapshot() as snapshot:
+ results = snapshot.execute_sql(
+ f'SELECT * FROM {self.table_name} ORDER BY id')
+ return list(results) if results else []
+
+ def drop_database(self, database_id: str):
+ """Drop the database."""
+ database = self.instance.database(database_id)
+ database.drop()
+
+ def shutdown(self):
+ """Stop the emulator."""
+ if self.emulator:
+ try:
+ self.emulator.stop()
+ except: # pylint: disable=bare-except
+ logging.error('Could not stop Spanner emulator.')
+
+ def get_emulator_host(self) -> str:
+ """Get the emulator host URL."""
+ return f'http://{self.host}'
+
+
[email protected]_gcp_java_expansion_service
[email protected](
+ os.environ.get('EXPANSION_JARS'),
+ "EXPANSION_JARS environment var is not provided, "
+ "indicating that jars have not been built")
[email protected](spanner is None, 'GCP dependencies are not installed.')
[email protected](
+ DockerContainer is None, 'testcontainers package is not installed.')
+class SpannerVectorWriterTest(unittest.TestCase):
+ """Integration tests for Spanner vector writer."""
+ @classmethod
+ def setUpClass(cls):
+ """Set up Spanner emulator for all tests."""
+ cls.project_id = 'test-project'
+ cls.instance_id = 'test-instance'
+ cls.table_name = 'embeddings'
+
+ cls.spanner_helper = SpannerEmulatorHelper(
+ cls.project_id, cls.instance_id, cls.table_name)
+
+ @classmethod
+ def tearDownClass(cls):
+ """Tear down Spanner emulator."""
+ cls.spanner_helper.shutdown()
+
+ def setUp(self):
+ """Create a unique database for each test."""
+ self.database_id = f'test_db_{uuid.uuid4().hex}'[:30]
+ self.spanner_helper.create_database(self.database_id)
+
+ def tearDown(self):
+ """Drop the test database."""
+ self.spanner_helper.drop_database(self.database_id)
+
+ def test_write_default_schema(self):
+ """Test writing with default schema (id, embedding, content, metadata)."""
+ # Create test chunks
+ chunks = [
+ Chunk(
+ id='doc1',
+ embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+ content=Content(text='First document'),
+ metadata={
+ 'source': 'test', 'page': 1
+ }),
+ Chunk(
+ id='doc2',
+ embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+ content=Content(text='Second document'),
+ metadata={
+ 'source': 'test', 'page': 2
+ }),
+ ]
+
+ # Create config with default schema
+ config = SpannerVectorWriterConfig(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ database_id=self.database_id,
+ table_name=self.table_name,
+ emulator_host=self.spanner_helper.get_emulator_host(),
+ )
+
+ # Write chunks
+ with TestPipeline() as p:
+ p.not_use_test_runner_api = True
+ _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+ # Verify data was written
+ results = self.spanner_helper.read_data(self.database_id)
+ self.assertEqual(len(results), 2)
+
+ # Check first row
+ row1 = results[0]
+ self.assertEqual(row1[0], 'doc1') # id
+ self.assertEqual(list(row1[1]), [1.0, 2.0, 3.0]) # embedding
+ self.assertEqual(row1[2], 'First document') # content
+ # metadata is JSON
+ metadata1 = row1[3]
+ self.assertEqual(metadata1['source'], 'test')
+ self.assertEqual(metadata1['page'], 1)
+
+ # Check second row
+ row2 = results[1]
+ self.assertEqual(row2[0], 'doc2')
+ self.assertEqual(list(row2[1]), [4.0, 5.0, 6.0])
+ self.assertEqual(row2[2], 'Second document')
+
+ def test_write_flattened_metadata(self):
+ """Test writing with flattened metadata fields."""
+ from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+ # Create custom database with flattened columns
+ self.spanner_helper.drop_database(self.database_id)
+ database = self.spanner_helper.instance.database(
+ self.database_id,
+ ddl_statements=[
+ f'''
+ CREATE TABLE {self.table_name} (
+ id STRING(1024) NOT NULL,
+ embedding ARRAY<FLOAT32>(vector_length=>3),
+ content STRING(MAX),
+ source STRING(MAX),
+ page_number INT64,
+ metadata JSON
+ ) PRIMARY KEY (id)'''
+ ])
+ database.create().result(120)
+
+ # Create test chunks
+ chunks = [
+ Chunk(
+ id='doc1',
+ embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+ content=Content(text='First document'),
+ metadata={
+ 'source': 'book.pdf', 'page': 10, 'author': 'John'
+ }),
+ Chunk(
+ id='doc2',
+ embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+ content=Content(text='Second document'),
+ metadata={
+ 'source': 'article.txt', 'page': 5, 'author': 'Jane'
+ }),
+ ]
+
+ # Create config with flattened metadata
+ specs = (
+ SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
+ with_content_spec().add_metadata_field(
+ 'source', str, column_name='source').add_metadata_field(
+ 'page', int,
+ column_name='page_number').with_metadata_spec().build())
+
+ config = SpannerVectorWriterConfig(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ database_id=self.database_id,
+ table_name=self.table_name,
+ column_specs=specs,
+ emulator_host=self.spanner_helper.get_emulator_host(),
+ )
+
+ # Write chunks
+ with TestPipeline() as p:
+ p.not_use_test_runner_api = True
+ _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+ # Verify data
+ database = self.spanner_helper.instance.database(self.database_id)
+ with database.snapshot() as snapshot:
+ results = snapshot.execute_sql(
+ f'SELECT id, embedding, content, source, page_number, metadata '
+ f'FROM {self.table_name} ORDER BY id')
+ rows = list(results)
+
+ self.assertEqual(len(rows), 2)
+
+ # Check first row
+ self.assertEqual(rows[0][0], 'doc1')
+ self.assertEqual(list(rows[0][1]), [1.0, 2.0, 3.0])
+ self.assertEqual(rows[0][2], 'First document')
+ self.assertEqual(rows[0][3], 'book.pdf') # flattened source
+ self.assertEqual(rows[0][4], 10) # flattened page_number
+
+ metadata1 = rows[0][5]
+ self.assertEqual(metadata1['author'], 'John')
+
+ def test_write_minimal_schema(self):
+ """Test writing with minimal schema (only id and embedding)."""
+ from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+ # Create custom database with minimal schema
+ self.spanner_helper.drop_database(self.database_id)
+ database = self.spanner_helper.instance.database(
+ self.database_id,
+ ddl_statements=[
+ f'''
+ CREATE TABLE {self.table_name} (
+ id STRING(1024) NOT NULL,
+ embedding ARRAY<FLOAT32>(vector_length=>3)
+ ) PRIMARY KEY (id)'''
+ ])
+ database.create().result(120)
+
+ # Create test chunks
+ chunks = [
+ Chunk(
+ id='doc1',
+ embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+ content=Content(text='First document'),
+ metadata={'source': 'test'}),
+ Chunk(
+ id='doc2',
+ embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+ content=Content(text='Second document'),
+ metadata={'source': 'test'}),
+ ]
+
+ # Create config with minimal schema
+ specs = (
+ SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().build(
+ ))
+
+ config = SpannerVectorWriterConfig(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ database_id=self.database_id,
+ table_name=self.table_name,
+ column_specs=specs,
+ emulator_host=self.spanner_helper.get_emulator_host(),
+ )
+
+ # Write chunks
+ with TestPipeline() as p:
+ p.not_use_test_runner_api = True
+ _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+ # Verify data
+ results = self.spanner_helper.read_data(self.database_id)
+ self.assertEqual(len(results), 2)
+ self.assertEqual(results[0][0], 'doc1')
+ self.assertEqual(list(results[0][1]), [1.0, 2.0, 3.0])
+
+ def test_write_with_converter(self):
+ """Test writing with custom converter function."""
+ from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+ # Create test chunks with embeddings that need normalization
+ chunks = [
+ Chunk(
+ id='doc1',
+ embedding=Embedding(dense_embedding=[3.0, 4.0, 0.0]),
+ content=Content(text='First document'),
+ metadata={'source': 'test'}),
+ ]
+
+ # Define normalizer
+ def normalize(vec):
+ norm = (sum(x**2 for x in vec)**0.5) or 1.0
+ return [x / norm for x in vec]
+
+ # Create config with normalized embeddings
+ specs = (
+ SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec(
+ convert_fn=normalize).with_content_spec().with_metadata_spec().
+ build())
+
+ config = SpannerVectorWriterConfig(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ database_id=self.database_id,
+ table_name=self.table_name,
+ column_specs=specs,
+ emulator_host=self.spanner_helper.get_emulator_host(),
+ )
+
+ # Write chunks
+ with TestPipeline() as p:
+ p.not_use_test_runner_api = True
+ _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+ # Verify data - embedding should be normalized
+ results = self.spanner_helper.read_data(self.database_id)
+ self.assertEqual(len(results), 1)
+
+ embedding = list(results[0][1])
+ # Original was [3.0, 4.0, 0.0], normalized should be [0.6, 0.8, 0.0]
+ self.assertAlmostEqual(embedding[0], 0.6, places=5)
+ self.assertAlmostEqual(embedding[1], 0.8, places=5)
+ self.assertAlmostEqual(embedding[2], 0.0, places=5)
+
+ # Check norm is 1.0
+ norm = sum(x**2 for x in embedding)**0.5
+ self.assertAlmostEqual(norm, 1.0, places=5)
+
+ def test_write_update_mode(self):
+ """Test writing with UPDATE mode."""
+ # First insert data
+ chunks_insert = [
+ Chunk(
+ id='doc1',
+ embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+ content=Content(text='Original content'),
+ metadata={'version': 1}),
+ ]
+
+ config_insert = SpannerVectorWriterConfig(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ database_id=self.database_id,
+ table_name=self.table_name,
+ write_mode='INSERT',
+ emulator_host=self.spanner_helper.get_emulator_host(),
+ )
+
+ with TestPipeline() as p:
+ p.not_use_test_runner_api = True
+ _ = (
+ p
+ | beam.Create(chunks_insert)
+ | config_insert.create_write_transform())
+
+ # Update existing row
+ chunks_update = [
+ Chunk(
+ id='doc1',
+ embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+ content=Content(text='Updated content'),
+ metadata={'version': 2}),
+ ]
+
+ config_update = SpannerVectorWriterConfig(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ database_id=self.database_id,
+ table_name=self.table_name,
+ write_mode='UPDATE',
+ emulator_host=self.spanner_helper.get_emulator_host(),
+ )
+
+ with TestPipeline() as p:
+ p.not_use_test_runner_api = True
+ _ = (
+ p
+ | beam.Create(chunks_update)
+ | config_update.create_write_transform())
+
+ # Verify update succeeded
+ results = self.spanner_helper.read_data(self.database_id)
+ self.assertEqual(len(results), 1)
+ self.assertEqual(results[0][0], 'doc1')
+ self.assertEqual(list(results[0][1]), [4.0, 5.0, 6.0])
+ self.assertEqual(results[0][2], 'Updated content')
+
+ metadata = results[0][3]
+ self.assertEqual(metadata['version'], 2)
+
+ def test_write_custom_column(self):
+ """Test writing with custom computed column."""
+ from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+ # Create custom database with computed column
+ self.spanner_helper.drop_database(self.database_id)
+ database = self.spanner_helper.instance.database(
+ self.database_id,
+ ddl_statements=[
+ f'''
+ CREATE TABLE {self.table_name} (
+ id STRING(1024) NOT NULL,
+ embedding ARRAY<FLOAT32>(vector_length=>3),
+ content STRING(MAX),
+ word_count INT64,
+ metadata JSON
+ ) PRIMARY KEY (id)'''
+ ])
+ database.create().result(120)
+
+ # Create test chunks
+ chunks = [
+ Chunk(
+ id='doc1',
+ embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+ content=Content(text='Hello world test'),
+ metadata={}),
+ Chunk(
+ id='doc2',
+ embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+ content=Content(text='This is a longer test document'),
+ metadata={}),
+ ]
+
+ # Create config with custom word_count column
+ specs = (
+ SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec(
+ ).with_content_spec().add_column(
+ column_name='word_count',
+ python_type=int,
+ value_fn=lambda chunk: len(chunk.content.text.split())).
+ with_metadata_spec().build())
+
+ config = SpannerVectorWriterConfig(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ database_id=self.database_id,
+ table_name=self.table_name,
+ column_specs=specs,
+ emulator_host=self.spanner_helper.get_emulator_host(),
+ )
+
+ # Write chunks
+ with TestPipeline() as p:
+ p.not_use_test_runner_api = True
+ _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+ # Verify data
+ database = self.spanner_helper.instance.database(self.database_id)
+ with database.snapshot() as snapshot:
+ results = snapshot.execute_sql(
+ f'SELECT id, word_count FROM {self.table_name} ORDER BY id')
+ rows = list(results)
+
+ self.assertEqual(len(rows), 2)
+ self.assertEqual(rows[0][1], 3) # "Hello world test" = 3 words
+ self.assertEqual(rows[1][1], 6) # 6 words
+
+ def test_write_with_timestamp(self):
+ """Test writing with timestamp columns."""
+ from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+ # Create database with timestamp column
+ self.spanner_helper.drop_database(self.database_id)
+ database = self.spanner_helper.instance.database(
+ self.database_id,
+ ddl_statements=[
+ f'''
+ CREATE TABLE {self.table_name} (
+ id STRING(1024) NOT NULL,
+ embedding ARRAY<FLOAT32>(vector_length=>3),
+ content STRING(MAX),
+ created_at TIMESTAMP,
+ metadata JSON
+ ) PRIMARY KEY (id)'''
+ ])
+ database.create().result(120)
+
+ # Create chunks with timestamp
+ timestamp_str = "2025-10-28T09:45:00.123456Z"
+ chunks = [
+ Chunk(
+ id='doc1',
+ embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+ content=Content(text='Document with timestamp'),
+ metadata={'created_at': timestamp_str}),
+ ]
+
+ # Create config with timestamp field
+ specs = (
+ SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
+ with_content_spec().add_metadata_field(
+ 'created_at', str,
+ column_name='created_at').with_metadata_spec().build())
+
+ config = SpannerVectorWriterConfig(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ database_id=self.database_id,
+ table_name=self.table_name,
+ column_specs=specs,
+ emulator_host=self.spanner_helper.get_emulator_host(),
+ )
+
+ # Write chunks
+ with TestPipeline() as p:
+ p.not_use_test_runner_api = True
+ _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+ # Verify timestamp was written
+ database = self.spanner_helper.instance.database(self.database_id)
+ with database.snapshot() as snapshot:
+ results = snapshot.execute_sql(
+ f'SELECT id, created_at FROM {self.table_name}')
+ rows = list(results)
+
+ self.assertEqual(len(rows), 1)
+ self.assertEqual(rows[0][0], 'doc1')
+    # Timestamp is returned as a datetime object by the Spanner client
+ self.assertIsNotNone(rows[0][1])
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()