This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 6b4dc555a24 Add spannerio vector writer. (#36654)
6b4dc555a24 is described below

commit 6b4dc555a245b4f82250677de21890955867314c
Author: claudevdm <[email protected]>
AuthorDate: Thu Oct 30 16:41:53 2025 -0400

    Add spannerio vector writer. (#36654)
    
    * draft
    
    * Simplify
    
    * Add xlang test markers.
    
    * Fix tests.
    
    * lints
    
    * Remove extractor_fn args.
    
    * Linter.
    
    ---------
    
    Co-authored-by: Claude <[email protected]>
---
 .../beam_PostCommit_Python_Xlang_Gcp_Direct.json   |   2 +-
 .../python/apache_beam/ml/rag/ingestion/spanner.py | 646 +++++++++++++++++++++
 .../ml/rag/ingestion/spanner_it_test.py            | 601 +++++++++++++++++++
 3 files changed, 1248 insertions(+), 1 deletion(-)

diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json
index 2504db607e4..95fef3e26ca 100644
--- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json
+++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json
@@ -1,4 +1,4 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to 
run",
-  "modification": 12
+  "modification": 13
 }
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/spanner.py b/sdks/python/apache_beam/ml/rag/ingestion/spanner.py
new file mode 100644
index 00000000000..f79db470bca
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/ingestion/spanner.py
@@ -0,0 +1,646 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Cloud Spanner vector store writer for RAG pipelines.
+
+This module provides a writer for storing embeddings and associated metadata
+in Google Cloud Spanner. It supports flexible schema configuration with the
+ability to flatten metadata fields into dedicated columns.
+
+Example usage:
+
+    Default schema (id, embedding, content, metadata):
+    >>> config = SpannerVectorWriterConfig(
+    ...     project_id="my-project",
+    ...     instance_id="my-instance",
+    ...     database_id="my-db",
+    ...     table_name="embeddings"
+    ... )
+
+    Flattened metadata fields:
+    >>> specs = (
+    ...     SpannerColumnSpecsBuilder()
+    ...     .with_id_spec()
+    ...     .with_embedding_spec()
+    ...     .with_content_spec()
+    ...     .add_metadata_field("source", str)
+    ...     .add_metadata_field("page_number", int, default=0)
+    ...     .with_metadata_spec()
+    ...     .build()
+    ... )
+    >>> config = SpannerVectorWriterConfig(
+    ...     project_id="my-project",
+    ...     instance_id="my-instance",
+    ...     database_id="my-db",
+    ...     table_name="embeddings",
+    ...     column_specs=specs
+    ... )
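+
+    Writing chunks with a config (a minimal sketch; assumes ``chunks`` is an
+    in-memory list of Chunk objects with dense embeddings):
+    >>> with beam.Pipeline() as p:
+    ...     _ = (
+    ...         p
+    ...         | beam.Create(chunks)
+    ...         | config.create_write_transform())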
+
+Spanner schema example:
+
+    CREATE TABLE embeddings (
+        id STRING(1024) NOT NULL,
+        embedding ARRAY<FLOAT32>(vector_length=>768),
+        content STRING(MAX),
+        source STRING(MAX),
+        page_number INT64,
+        metadata JSON
+    ) PRIMARY KEY (id)
+"""
+
+import functools
+import json
+from dataclasses import dataclass
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Literal
+from typing import NamedTuple
+from typing import Optional
+from typing import Type
+
+import apache_beam as beam
+from apache_beam.coders import registry
+from apache_beam.coders.row_coder import RowCoder
+from apache_beam.io.gcp import spanner
+from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
+from apache_beam.ml.rag.types import Chunk
+
+
+@dataclass
+class SpannerColumnSpec:
+  """Column specification for Spanner vector writes.
+  
+  Defines how to extract and format values from Chunks for insertion into
+  Spanner table columns. Each spec maps to one column in the target table.
+  
+  Attributes:
+      column_name: Name of the Spanner table column
+      python_type: Python type for the NamedTuple field (required for RowCoder)
+      value_fn: Function to extract value from a Chunk
+  
+  Examples:
+      String column:
+      >>> SpannerColumnSpec(
+      ...     column_name="id",
+      ...     python_type=str,
+      ...     value_fn=lambda chunk: chunk.id
+      ... )
+      
+      Array column with conversion:
+      >>> SpannerColumnSpec(
+      ...     column_name="embedding",
+      ...     python_type=List[float],
+      ...     value_fn=lambda chunk: chunk.embedding.dense_embedding
+      ... )
+  """
+  column_name: str
+  python_type: Type
+  value_fn: Callable[[Chunk], Any]
+
+
+def _extract_and_convert(extract_fn, convert_fn, chunk):
+  """Extract a value from a chunk, applying convert_fn to it if provided."""
+  if convert_fn:
+    return convert_fn(extract_fn(chunk))
+  return extract_fn(chunk)
+
+
+class SpannerColumnSpecsBuilder:
+  """Builder for creating Spanner column specifications.
+  
+  Provides a fluent API for defining table schemas and how to populate them
+  from Chunk objects. Supports standard Chunk fields (id, embedding, content,
+  metadata) and flattening metadata fields into dedicated columns.
+  
+  Example:
+      >>> specs = (
+      ...     SpannerColumnSpecsBuilder()
+      ...     .with_id_spec()
+      ...     .with_embedding_spec()
+      ...     .with_content_spec()
+      ...     .add_metadata_field("source", str)
+      ...     .with_metadata_spec()
+      ...     .build()
+      ... )
+  """
+  def __init__(self):
+    self._specs: List[SpannerColumnSpec] = []
+
+  @staticmethod
+  def with_defaults() -> 'SpannerColumnSpecsBuilder':
+    """Create builder with default schema.
+    
+    Default schema includes:
+    - id (STRING): Chunk ID
+    - embedding (ARRAY<FLOAT32>): Dense embedding vector
+    - content (STRING): Chunk content text
+    - metadata (JSON): Full metadata as JSON
+    
+    Returns:
+        Builder with default column specifications
+    """
+    return (
+        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
+        with_content_spec().with_metadata_spec())
+
+  def with_id_spec(
+      self,
+      column_name: str = "id",
+      python_type: Type = str,
+      convert_fn: Optional[Callable[[str], Any]] = None
+  ) -> 'SpannerColumnSpecsBuilder':
+    """Add ID column specification.
+    
+    Args:
+        column_name: Column name (default: "id")
+        python_type: Python type (default: str)
+        convert_fn: Optional converter (e.g., to cast to int)
+    
+    Returns:
+        Self for method chaining
+    
+    Examples:
+        Default string ID:
+        >>> builder.with_id_spec()
+        
+        Integer ID with conversion:
+        >>> builder.with_id_spec(
+        ...     python_type=int,
+        ...     convert_fn=lambda id: int(id.split('_')[1])
+        ... )
+    """
+
+    self._specs.append(
+        SpannerColumnSpec(
+            column_name=column_name,
+            python_type=python_type,
+            value_fn=functools.partial(
+                _extract_and_convert, lambda chunk: chunk.id, convert_fn)))
+    return self
+
+  def with_embedding_spec(
+      self,
+      column_name: str = "embedding",
+      convert_fn: Optional[Callable[[List[float]], List[float]]] = None
+  ) -> 'SpannerColumnSpecsBuilder':
+    """Add embedding array column (ARRAY<FLOAT32> or ARRAY<FLOAT64>).
+    
+    Args:
+        column_name: Column name (default: "embedding")
+        convert_fn: Optional converter (e.g., normalize, quantize)
+    
+    Returns:
+        Self for method chaining
+    
+    Examples:
+        Default embedding:
+        >>> builder.with_embedding_spec()
+        
+        Normalized embedding:
+        >>> def normalize(vec):
+        ...     norm = (sum(x**2 for x in vec) ** 0.5) or 1.0
+        ...     return [x/norm for x in vec]
+        >>> builder.with_embedding_spec(convert_fn=normalize)
+        
+        Rounded precision:
+        >>> builder.with_embedding_spec(
+        ...     convert_fn=lambda vec: [round(x, 4) for x in vec]
+        ... )
+    """
+    def extract_fn(chunk: Chunk) -> List[float]:
+      if chunk.embedding is None or chunk.embedding.dense_embedding is None:
+        raise ValueError(f'Chunk must contain embedding: {chunk}')
+      return chunk.embedding.dense_embedding
+
+    self._specs.append(
+        SpannerColumnSpec(
+            column_name=column_name,
+            python_type=List[float],
+            value_fn=functools.partial(
+                _extract_and_convert, extract_fn, convert_fn)))
+    return self
+
+  def with_content_spec(
+      self,
+      column_name: str = "content",
+      python_type: Type = str,
+      convert_fn: Optional[Callable[[str], Any]] = None
+  ) -> 'SpannerColumnSpecsBuilder':
+    """Add content column.
+    
+    Args:
+        column_name: Column name (default: "content")
+        python_type: Python type (default: str)
+        convert_fn: Optional converter
+    
+    Returns:
+        Self for method chaining
+    
+    Examples:
+        Default text content:
+        >>> builder.with_content_spec()
+        
+        Content length as integer:
+        >>> builder.with_content_spec(
+        ...     column_name="content_length",
+        ...     python_type=int,
+        ...     convert_fn=lambda text: len(text.split())
+        ... )
+        
+        Truncated content:
+        >>> builder.with_content_spec(
+        ...     convert_fn=lambda text: text[:1000]
+        ... )
+    """
+    def extract_fn(chunk: Chunk) -> str:
+      if chunk.content.text is None:
+        raise ValueError(f'Chunk must contain content: {chunk}')
+      return chunk.content.text
+
+    self._specs.append(
+        SpannerColumnSpec(
+            column_name=column_name,
+            python_type=python_type,
+            value_fn=functools.partial(
+                _extract_and_convert, extract_fn, convert_fn)))
+    return self
+
+  def with_metadata_spec(
+      self, column_name: str = "metadata") -> 'SpannerColumnSpecsBuilder':
+    """Add metadata JSON column.
+    
+    Stores the full metadata dictionary as a JSON string in Spanner.
+    
+    Args:
+        column_name: Column name (default: "metadata")
+    
+    Returns:
+        Self for method chaining
+    
+    Note:
+        Metadata is automatically converted to a JSON string using json.dumps()
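+    
+    Example:
+        Store full metadata under a custom column name:
+        >>> builder.with_metadata_spec(column_name="metadata_json")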
+    """
+    value_fn = lambda chunk: json.dumps(chunk.metadata)
+    self._specs.append(
+        SpannerColumnSpec(
+            column_name=column_name, python_type=str, value_fn=value_fn))
+    return self
+
+  def add_metadata_field(
+      self,
+      field: str,
+      python_type: Type,
+      column_name: Optional[str] = None,
+      convert_fn: Optional[Callable[[Any], Any]] = None,
+      default: Any = None) -> 'SpannerColumnSpecsBuilder':
+    """Flatten a metadata field into its own column.
+    
+    Extracts a specific field from chunk.metadata and stores it in a
+    dedicated table column.
+    
+    Args:
+        field: Key in chunk.metadata to extract
+        python_type: Python type (must be explicitly specified)
+        column_name: Column name (default: same as field)
+        convert_fn: Optional converter for type casting/transformation
+        default: Default value if field is missing from metadata
+    
+    Returns:
+        Self for method chaining
+    
+    Examples:
+        String field:
+        >>> builder.add_metadata_field("source", str)
+        
+        Integer with default:
+        >>> builder.add_metadata_field(
+        ...     "page_number",
+        ...     int,
+        ...     default=0
+        ... )
+        
+        Float with conversion:
+        >>> builder.add_metadata_field(
+        ...     "confidence",
+        ...     float,
+        ...     convert_fn=lambda x: round(float(x), 2),
+        ...     default=0.0
+        ... )
+        
+        List of strings:
+        >>> builder.add_metadata_field(
+        ...     "tags",
+        ...     List[str],
+        ...     default=[]
+        ... )
+        
+        Timestamp with conversion:
+        >>> builder.add_metadata_field(
+        ...     "created_at",
+        ...     str,
+        ...     convert_fn=lambda ts: ts.isoformat()
+        ... )
+    """
+    name = column_name or field
+
+    def value_fn(chunk: Chunk) -> Any:
+      return chunk.metadata.get(field, default)
+
+    self._specs.append(
+        SpannerColumnSpec(
+            column_name=name,
+            python_type=python_type,
+            value_fn=functools.partial(
+                _extract_and_convert, value_fn, convert_fn)))
+    return self
+
+  def add_column(
+      self,
+      column_name: str,
+      python_type: Type,
+      value_fn: Callable[[Chunk], Any]) -> 'SpannerColumnSpecsBuilder':
+    """Add a custom column with full control.
+    
+    Args:
+        column_name: Column name
+        python_type: Python type (required)
+        value_fn: Value extraction function
+    
+    Returns:
+        Self for method chaining
+    
+    Examples:
+        Boolean flag:
+        >>> builder.add_column(
+        ...     column_name="has_code",
+        ...     python_type=bool,
+        ...     value_fn=lambda chunk: "```" in chunk.content.text
+        ... )
+        
+        Computed value:
+        >>> builder.add_column(
+        ...     column_name="word_count",
+        ...     python_type=int,
+        ...     value_fn=lambda chunk: len(chunk.content.text.split())
+        ... )
+    """
+    self._specs.append(
+        SpannerColumnSpec(
+            column_name=column_name, python_type=python_type,
+            value_fn=value_fn))
+    return self
+
+  def build(self) -> List[SpannerColumnSpec]:
+    """Build the final list of column specifications.
+    
+    Returns:
+        List of SpannerColumnSpec objects
+    """
+    return self._specs.copy()
+
+
+class _SpannerSchemaBuilder:
+  """Internal: Builds NamedTuple schema and registers RowCoder.
+  
+  Creates a NamedTuple type from column specifications and registers it
+  with Beam's RowCoder for serialization.
+  """
+  def __init__(self, table_name: str, column_specs: List[SpannerColumnSpec]):
+    """Initialize schema builder.
+    
+    Args:
+        table_name: Table name (used in NamedTuple type name)
+        column_specs: List of column specifications
+    
+    Raises:
+        ValueError: If duplicate column names are found
+    """
+    self.table_name = table_name
+    self.column_specs = column_specs
+
+    # Validate no duplicates
+    names = [col.column_name for col in column_specs]
+    duplicates = set(name for name in names if names.count(name) > 1)
+    if duplicates:
+      raise ValueError(f"Duplicate column names: {duplicates}")
+
+    # Create NamedTuple type
+    fields = [(col.column_name, col.python_type) for col in column_specs]
+    type_name = f"SpannerVectorRecord_{table_name}"
+    self.record_type = NamedTuple(type_name, fields)  # type: ignore
+
+    # Register coder
+    registry.register_coder(self.record_type, RowCoder)
+
+  def create_converter(self) -> Callable[[Chunk], NamedTuple]:
+    """Create converter function from Chunk to NamedTuple record.
+    
+    Returns:
+        Function that converts a Chunk to a NamedTuple record
+    """
+    def convert(chunk: Chunk) -> self.record_type:  # type: ignore
+      values = {
+          col.column_name: col.value_fn(chunk)
+          for col in self.column_specs
+      }
+      return self.record_type(**values)  # type: ignore
+
+    return convert
+
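+# Illustrative sketch of the conversion performed by create_converter() above,
+# under the default column specs (Embedding and Content are assumed imports
+# from apache_beam.ml.rag.types; values are hypothetical):
+#
+#   builder = _SpannerSchemaBuilder(
+#       "embeddings", SpannerColumnSpecsBuilder.with_defaults().build())
+#   record = builder.create_converter()(
+#       Chunk(
+#           id="doc1",
+#           embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]),
+#           content=Content(text="hello"),
+#           metadata={"source": "test"}))
+#   # record.id == "doc1"; record.metadata == '{"source": "test"}'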
+
+class SpannerVectorWriterConfig(VectorDatabaseWriteConfig):
+  """Configuration for writing vectors to Cloud Spanner.
+  
+  Supports flexible schema configuration through column specifications and
+  provides control over Spanner-specific write parameters.
+  
+  Examples:
+      Default schema:
+      >>> config = SpannerVectorWriterConfig(
+      ...     project_id="my-project",
+      ...     instance_id="my-instance",
+      ...     database_id="my-db",
+      ...     table_name="embeddings"
+      ... )
+      
+      Custom schema with flattened metadata:
+      >>> specs = (
+      ...     SpannerColumnSpecsBuilder()
+      ...     .with_id_spec()
+      ...     .with_embedding_spec()
+      ...     .with_content_spec()
+      ...     .add_metadata_field("source", str)
+      ...     .add_metadata_field("page_number", int, default=0)
+      ...     .with_metadata_spec()
+      ...     .build()
+      ... )
+      >>> config = SpannerVectorWriterConfig(
+      ...     project_id="my-project",
+      ...     instance_id="my-instance",
+      ...     database_id="my-db",
+      ...     table_name="embeddings",
+      ...     column_specs=specs
+      ... )
+      
+      With emulator:
+      >>> config = SpannerVectorWriterConfig(
+      ...     project_id="test-project",
+      ...     instance_id="test-instance",
+      ...     database_id="test-db",
+      ...     table_name="embeddings",
+      ...     emulator_host="http://localhost:9010"
+      ... )
+  """
+  def __init__(
+      self,
+      project_id: str,
+      instance_id: str,
+      database_id: str,
+      table_name: str,
+      *,
+      # Schema configuration
+      column_specs: Optional[List[SpannerColumnSpec]] = None,
+      # Write operation type
+      write_mode: Literal["INSERT", "UPDATE", "REPLACE",
+                          "INSERT_OR_UPDATE"] = "INSERT_OR_UPDATE",
+      # Batching configuration
+      max_batch_size_bytes: Optional[int] = None,
+      max_number_mutations: Optional[int] = None,
+      max_number_rows: Optional[int] = None,
+      grouping_factor: Optional[int] = None,
+      # Networking
+      host: Optional[str] = None,
+      emulator_host: Optional[str] = None,
+      expansion_service: Optional[str] = None,
+      # Retry/deadline configuration
+      commit_deadline: Optional[int] = None,
+      max_cumulative_backoff: Optional[int] = None,
+      # Error handling
+      failure_mode: Optional[
+          spanner.FailureMode] = spanner.FailureMode.REPORT_FAILURES,
+      high_priority: bool = False,
+      # Additional Spanner arguments
+      **spanner_kwargs):
+    """Initialize Spanner vector writer configuration.
+    
+    Args:
+        project_id: GCP project ID
+        instance_id: Spanner instance ID
+        database_id: Spanner database ID
+        table_name: Target table name
+        column_specs: Schema configuration using SpannerColumnSpecsBuilder.
+            If None, uses default schema (id, embedding, content, metadata)
+        write_mode: Spanner write operation type:
+            - INSERT: Fail if row exists
+            - UPDATE: Fail if row doesn't exist
+            - REPLACE: Delete then insert
+            - INSERT_OR_UPDATE: Insert or update if exists (default)
+        max_batch_size_bytes: Maximum bytes per mutation batch (default: 1MB)
+        max_number_mutations: Maximum cell mutations per batch (default: 5000)
+        max_number_rows: Maximum rows per batch (default: 500)
+        grouping_factor: Multiple of the max mutation limits used to select
+            mutations for sorting and batching (default: 1000)
+        host: Spanner host URL (usually not needed)
+        emulator_host: Spanner emulator host (e.g., "http://localhost:9010")
+        expansion_service: Java expansion service address (host:port)
+        commit_deadline: Commit API deadline in seconds (default: 15)
+        max_cumulative_backoff: Max retry backoff seconds (default: 900)
+        failure_mode: Error handling strategy:
+            - FAIL_FAST: Throw exception for any failure
+            - REPORT_FAILURES: Continue processing (default)
+        high_priority: Use high priority for operations (default: False)
+        **spanner_kwargs: Additional keyword arguments to pass to the
+            underlying Spanner write transform. Use this to pass any
+            Spanner-specific parameters not explicitly exposed by this config.
+    """
+    self.project_id = project_id
+    self.instance_id = instance_id
+    self.database_id = database_id
+    self.table_name = table_name
+    self.write_mode = write_mode
+    self.max_batch_size_bytes = max_batch_size_bytes
+    self.max_number_mutations = max_number_mutations
+    self.max_number_rows = max_number_rows
+    self.grouping_factor = grouping_factor
+    self.host = host
+    self.emulator_host = emulator_host
+    self.expansion_service = expansion_service
+    self.commit_deadline = commit_deadline
+    self.max_cumulative_backoff = max_cumulative_backoff
+    self.failure_mode = failure_mode
+    self.high_priority = high_priority
+    self.spanner_kwargs = spanner_kwargs
+
+    # Use defaults if not provided
+    specs = column_specs or SpannerColumnSpecsBuilder.with_defaults().build()
+
+    # Create schema builder (NamedTuple + RowCoder registration)
+    self.schema_builder = _SpannerSchemaBuilder(table_name, specs)
+
+  def create_write_transform(self) -> beam.PTransform:
+    """Create the Spanner write PTransform.
+    
+    Returns:
+        PTransform for writing to Spanner
+    """
+    return _WriteToSpannerVectorDatabase(self)
+
+
+class _WriteToSpannerVectorDatabase(beam.PTransform):
+  """Internal: PTransform for writing to Spanner vector database."""
+  def __init__(self, config: SpannerVectorWriterConfig):
+    """Initialize write transform.
+    
+    Args:
+        config: Spanner writer configuration
+    """
+    self.config = config
+    self.schema_builder = config.schema_builder
+
+  def expand(self, pcoll: beam.PCollection[Chunk]):
+    """Expand the transform.
+    
+    Args:
+        pcoll: PCollection of Chunks to write
+    """
+    # Select appropriate Spanner write transform based on write_mode
+    write_transform_class = {
+        "INSERT": spanner.SpannerInsert,
+        "UPDATE": spanner.SpannerUpdate,
+        "REPLACE": spanner.SpannerReplace,
+        "INSERT_OR_UPDATE": spanner.SpannerInsertOrUpdate,
+    }[self.config.write_mode]
+
+    return (
+        pcoll
+        | "Convert to Records" >> beam.Map(
+            self.schema_builder.create_converter()).with_output_types(
+                self.schema_builder.record_type)
+        | "Write to Spanner" >> write_transform_class(
+            project_id=self.config.project_id,
+            instance_id=self.config.instance_id,
+            database_id=self.config.database_id,
+            table=self.config.table_name,
+            max_batch_size_bytes=self.config.max_batch_size_bytes,
+            max_number_mutations=self.config.max_number_mutations,
+            max_number_rows=self.config.max_number_rows,
+            grouping_factor=self.config.grouping_factor,
+            host=self.config.host,
+            emulator_host=self.config.emulator_host,
+            commit_deadline=self.config.commit_deadline,
+            max_cumulative_backoff=self.config.max_cumulative_backoff,
+            failure_mode=self.config.failure_mode,
+            expansion_service=self.config.expansion_service,
+            high_priority=self.config.high_priority,
+            **self.config.spanner_kwargs))
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py
new file mode 100644
index 00000000000..ab9a982a81f
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py
@@ -0,0 +1,601 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Integration tests for Spanner vector store writer."""
+
+import logging
+import os
+import time
+import unittest
+import uuid
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.ml.rag.ingestion.spanner import SpannerVectorWriterConfig
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.types import Content
+from apache_beam.ml.rag.types import Embedding
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# pylint: disable=wrong-import-order, wrong-import-position
+try:
+  from google.cloud import spanner
+except ImportError:
+  spanner = None
+
+try:
+  from testcontainers.core.container import DockerContainer
+except ImportError:
+  DockerContainer = None
+# pylint: enable=wrong-import-order, wrong-import-position
+
+
+def retry(fn, retries, err_msg, *args, **kwargs):
+  """Retry a function with exponential backoff."""
+  for _ in range(retries):
+    try:
+      return fn(*args, **kwargs)
+    except:  # pylint: disable=bare-except
+      time.sleep(1)
+  logging.error(err_msg)
+  raise RuntimeError(err_msg)
+
+
+class SpannerEmulatorHelper:
+  """Helper for managing Spanner emulator lifecycle."""
+  def __init__(self, project_id: str, instance_id: str, table_name: str):
+    self.project_id = project_id
+    self.instance_id = instance_id
+    self.table_name = table_name
+    self.host = None
+
+    # Start emulator
+    self.emulator = DockerContainer(
+        'gcr.io/cloud-spanner-emulator/emulator:latest').with_exposed_ports(
+            9010, 9020)
+    retry(self.emulator.start, 3, 'Could not start spanner emulator.')
+    time.sleep(3)
+
+    self.host = f'{self.emulator.get_container_host_ip()}:' \
+                f'{self.emulator.get_exposed_port(9010)}'
+    os.environ['SPANNER_EMULATOR_HOST'] = self.host
+
+    # Create client and instance
+    self.client = spanner.Client(project_id)
+    self.instance = self.client.instance(instance_id)
+    self.create_instance()
+
+  def create_instance(self):
+    """Create Spanner instance in emulator."""
+    self.instance.create().result(120)
+
+  def create_database(self, database_id: str):
+    """Create database with default vector table schema."""
+    database = self.instance.database(
+        database_id,
+        ddl_statements=[
+            f'''
+          CREATE TABLE {self.table_name} (
+              id STRING(1024) NOT NULL,
+              embedding ARRAY<FLOAT32>(vector_length=>3),
+              content STRING(MAX),
+              metadata JSON
+          ) PRIMARY KEY (id)'''
+        ])
+    database.create().result(120)
+
+  def read_data(self, database_id: str):
+    """Read all data from the table."""
+    database = self.instance.database(database_id)
+    with database.snapshot() as snapshot:
+      results = snapshot.execute_sql(
+          f'SELECT * FROM {self.table_name} ORDER BY id')
+      return list(results) if results else []
+
+  def drop_database(self, database_id: str):
+    """Drop the database."""
+    database = self.instance.database(database_id)
+    database.drop()
+
+  def shutdown(self):
+    """Stop the emulator."""
+    if self.emulator:
+      try:
+        self.emulator.stop()
+      except:  # pylint: disable=bare-except
+        logging.error('Could not stop Spanner emulator.')
+
+  def get_emulator_host(self) -> str:
+    """Get the emulator host URL."""
+    return f'http://{self.host}'
+
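+# Typical helper lifecycle, as exercised by the test class below (a sketch;
+# the database id is hypothetical):
+#
+#   helper = SpannerEmulatorHelper(
+#       'test-project', 'test-instance', 'embeddings')
+#   helper.create_database('test_db_abc')
+#   # ... run a pipeline against helper.get_emulator_host() ...
+#   helper.drop_database('test_db_abc')
+#   helper.shutdown()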
+
[email protected]_gcp_java_expansion_service
[email protected](
+    os.environ.get('EXPANSION_JARS'),
+    "EXPANSION_JARS environment var is not provided, "
+    "indicating that jars have not been built")
[email protected](spanner is None, 'GCP dependencies are not installed.')
[email protected](
+    DockerContainer is None, 'testcontainers package is not installed.')
+class SpannerVectorWriterTest(unittest.TestCase):
+  """Integration tests for Spanner vector writer."""
+  @classmethod
+  def setUpClass(cls):
+    """Set up Spanner emulator for all tests."""
+    cls.project_id = 'test-project'
+    cls.instance_id = 'test-instance'
+    cls.table_name = 'embeddings'
+
+    cls.spanner_helper = SpannerEmulatorHelper(
+        cls.project_id, cls.instance_id, cls.table_name)
+
+  @classmethod
+  def tearDownClass(cls):
+    """Tear down Spanner emulator."""
+    cls.spanner_helper.shutdown()
+
+  def setUp(self):
+    """Create a unique database for each test."""
+    self.database_id = f'test_db_{uuid.uuid4().hex}'[:30]
+    self.spanner_helper.create_database(self.database_id)
+
+  def tearDown(self):
+    """Drop the test database."""
+    self.spanner_helper.drop_database(self.database_id)
+
+  def test_write_default_schema(self):
+    """Test writing with default schema (id, embedding, content, metadata)."""
+    # Create test chunks
+    chunks = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+            content=Content(text='First document'),
+            metadata={
+                'source': 'test', 'page': 1
+            }),
+        Chunk(
+            id='doc2',
+            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+            content=Content(text='Second document'),
+            metadata={
+                'source': 'test', 'page': 2
+            }),
+    ]
+
+    # Create config with default schema
+    config = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    # Write chunks
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+    # Verify data was written
+    results = self.spanner_helper.read_data(self.database_id)
+    self.assertEqual(len(results), 2)
+
+    # Check first row
+    row1 = results[0]
+    self.assertEqual(row1[0], 'doc1')  # id
+    self.assertEqual(list(row1[1]), [1.0, 2.0, 3.0])  # embedding
+    self.assertEqual(row1[2], 'First document')  # content
+    # metadata is JSON
+    metadata1 = row1[3]
+    self.assertEqual(metadata1['source'], 'test')
+    self.assertEqual(metadata1['page'], 1)
+
+    # Check second row
+    row2 = results[1]
+    self.assertEqual(row2[0], 'doc2')
+    self.assertEqual(list(row2[1]), [4.0, 5.0, 6.0])
+    self.assertEqual(row2[2], 'Second document')
+
+  def test_write_flattened_metadata(self):
+    """Test writing with flattened metadata fields."""
+    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+    # Create custom database with flattened columns
+    self.spanner_helper.drop_database(self.database_id)
+    database = self.spanner_helper.instance.database(
+        self.database_id,
+        ddl_statements=[
+            f'''
+          CREATE TABLE {self.table_name} (
+              id STRING(1024) NOT NULL,
+              embedding ARRAY<FLOAT32>(vector_length=>3),
+              content STRING(MAX),
+              source STRING(MAX),
+              page_number INT64,
+              metadata JSON
+          ) PRIMARY KEY (id)'''
+        ])
+    database.create().result(120)
+
+    # Create test chunks
+    chunks = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+            content=Content(text='First document'),
+            metadata={
+                'source': 'book.pdf', 'page': 10, 'author': 'John'
+            }),
+        Chunk(
+            id='doc2',
+            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+            content=Content(text='Second document'),
+            metadata={
+                'source': 'article.txt', 'page': 5, 'author': 'Jane'
+            }),
+    ]
+
+    # Create config with flattened metadata
+    specs = (
+        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
+        with_content_spec().add_metadata_field(
+            'source', str, column_name='source').add_metadata_field(
+                'page', int,
+                column_name='page_number').with_metadata_spec().build())
+
+    config = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        column_specs=specs,
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    # Write chunks
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+    # Verify data
+    database = self.spanner_helper.instance.database(self.database_id)
+    with database.snapshot() as snapshot:
+      results = snapshot.execute_sql(
+          f'SELECT id, embedding, content, source, page_number, metadata '
+          f'FROM {self.table_name} ORDER BY id')
+      rows = list(results)
+
+    self.assertEqual(len(rows), 2)
+
+    # Check first row
+    self.assertEqual(rows[0][0], 'doc1')
+    self.assertEqual(list(rows[0][1]), [1.0, 2.0, 3.0])
+    self.assertEqual(rows[0][2], 'First document')
+    self.assertEqual(rows[0][3], 'book.pdf')  # flattened source
+    self.assertEqual(rows[0][4], 10)  # flattened page_number
+
+    metadata1 = rows[0][5]
+    self.assertEqual(metadata1['author'], 'John')
+
+  def test_write_minimal_schema(self):
+    """Test writing with minimal schema (only id and embedding)."""
+    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+    # Create custom database with minimal schema
+    self.spanner_helper.drop_database(self.database_id)
+    database = self.spanner_helper.instance.database(
+        self.database_id,
+        ddl_statements=[
+            f'''
+          CREATE TABLE {self.table_name} (
+              id STRING(1024) NOT NULL,
+              embedding ARRAY<FLOAT32>(vector_length=>3)
+          ) PRIMARY KEY (id)'''
+        ])
+    database.create().result(120)
+
+    # Create test chunks
+    chunks = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+            content=Content(text='First document'),
+            metadata={'source': 'test'}),
+        Chunk(
+            id='doc2',
+            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+            content=Content(text='Second document'),
+            metadata={'source': 'test'}),
+    ]
+
+    # Create config with minimal schema
+    specs = (
+        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().build(
+        ))
+
+    config = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        column_specs=specs,
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    # Write chunks
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+    # Verify data
+    results = self.spanner_helper.read_data(self.database_id)
+    self.assertEqual(len(results), 2)
+    self.assertEqual(results[0][0], 'doc1')
+    self.assertEqual(list(results[0][1]), [1.0, 2.0, 3.0])
+
+  def test_write_with_converter(self):
+    """Test writing with custom converter function."""
+    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+    # Create test chunks with embeddings that need normalization
+    chunks = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[3.0, 4.0, 0.0]),
+            content=Content(text='First document'),
+            metadata={'source': 'test'}),
+    ]
+
+    # Define normalizer
+    def normalize(vec):
+      norm = (sum(x**2 for x in vec)**0.5) or 1.0
+      return [x / norm for x in vec]
+
+    # Create config with normalized embeddings
+    specs = (
+        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec(
+            convert_fn=normalize).with_content_spec().with_metadata_spec().
+        build())
+
+    config = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        column_specs=specs,
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    # Write chunks
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+    # Verify data - embedding should be normalized
+    results = self.spanner_helper.read_data(self.database_id)
+    self.assertEqual(len(results), 1)
+
+    embedding = list(results[0][1])
+    # Original was [3.0, 4.0, 0.0], normalized should be [0.6, 0.8, 0.0]
+    self.assertAlmostEqual(embedding[0], 0.6, places=5)
+    self.assertAlmostEqual(embedding[1], 0.8, places=5)
+    self.assertAlmostEqual(embedding[2], 0.0, places=5)
+
+    # Check norm is 1.0
+    norm = sum(x**2 for x in embedding)**0.5
+    self.assertAlmostEqual(norm, 1.0, places=5)
+
+  def test_write_update_mode(self):
+    """Test writing with UPDATE mode."""
+    # First insert data
+    chunks_insert = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+            content=Content(text='Original content'),
+            metadata={'version': 1}),
+    ]
+
+    config_insert = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        write_mode='INSERT',
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (
+          p
+          | beam.Create(chunks_insert)
+          | config_insert.create_write_transform())
+
+    # Update existing row
+    chunks_update = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+            content=Content(text='Updated content'),
+            metadata={'version': 2}),
+    ]
+
+    config_update = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        write_mode='UPDATE',
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (
+          p
+          | beam.Create(chunks_update)
+          | config_update.create_write_transform())
+
+    # Verify update succeeded
+    results = self.spanner_helper.read_data(self.database_id)
+    self.assertEqual(len(results), 1)
+    self.assertEqual(results[0][0], 'doc1')
+    self.assertEqual(list(results[0][1]), [4.0, 5.0, 6.0])
+    self.assertEqual(results[0][2], 'Updated content')
+
+    metadata = results[0][3]
+    self.assertEqual(metadata['version'], 2)
+
+  def test_write_custom_column(self):
+    """Test writing with custom computed column."""
+    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+    # Create custom database with computed column
+    self.spanner_helper.drop_database(self.database_id)
+    database = self.spanner_helper.instance.database(
+        self.database_id,
+        ddl_statements=[
+            f'''
+          CREATE TABLE {self.table_name} (
+              id STRING(1024) NOT NULL,
+              embedding ARRAY<FLOAT32>(vector_length=>3),
+              content STRING(MAX),
+              word_count INT64,
+              metadata JSON
+          ) PRIMARY KEY (id)'''
+        ])
+    database.create().result(120)
+
+    # Create test chunks
+    chunks = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+            content=Content(text='Hello world test'),
+            metadata={}),
+        Chunk(
+            id='doc2',
+            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+            content=Content(text='This is a longer test document'),
+            metadata={}),
+    ]
+
+    # Create config with custom word_count column
+    specs = (
+        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec(
+        ).with_content_spec().add_column(
+            column_name='word_count',
+            python_type=int,
+            value_fn=lambda chunk: len(chunk.content.text.split())).
+        with_metadata_spec().build())
+
+    config = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        column_specs=specs,
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    # Write chunks
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+    # Verify data
+    database = self.spanner_helper.instance.database(self.database_id)
+    with database.snapshot() as snapshot:
+      results = snapshot.execute_sql(
+          f'SELECT id, word_count FROM {self.table_name} ORDER BY id')
+      rows = list(results)
+
+    self.assertEqual(len(rows), 2)
+    self.assertEqual(rows[0][1], 3)  # "Hello world test" = 3 words
+    self.assertEqual(rows[1][1], 6)  # 6 words
+
+  def test_write_with_timestamp(self):
+    """Test writing with timestamp columns."""
+    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+    # Create database with timestamp column
+    self.spanner_helper.drop_database(self.database_id)
+    database = self.spanner_helper.instance.database(
+        self.database_id,
+        ddl_statements=[
+            f'''
+          CREATE TABLE {self.table_name} (
+              id STRING(1024) NOT NULL,
+              embedding ARRAY<FLOAT32>(vector_length=>3),
+              content STRING(MAX),
+              created_at TIMESTAMP,
+              metadata JSON
+          ) PRIMARY KEY (id)'''
+        ])
+    database.create().result(120)
+
+    # Create chunks with timestamp
+    timestamp_str = "2025-10-28T09:45:00.123456Z"
+    chunks = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+            content=Content(text='Document with timestamp'),
+            metadata={'created_at': timestamp_str}),
+    ]
+
+    # Create config with timestamp field
+    specs = (
+        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
+        with_content_spec().add_metadata_field(
+            'created_at', str,
+            column_name='created_at').with_metadata_spec().build())
+
+    config = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        column_specs=specs,
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    # Write chunks
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+    # Verify timestamp was written
+    database = self.spanner_helper.instance.database(self.database_id)
+    with database.snapshot() as snapshot:
+      results = snapshot.execute_sql(
+          f'SELECT id, created_at FROM {self.table_name}')
+      rows = list(results)
+
+    self.assertEqual(len(rows), 1)
+    self.assertEqual(rows[0][0], 'doc1')
+    # Timestamp is returned as a datetime object by the Spanner client
+    self.assertIsNotNone(rows[0][1])
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

