gemini-code-assist[bot] commented on code in PR #36654:
URL: https://github.com/apache/beam/pull/36654#discussion_r2473637571


##########
sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py:
##########
@@ -0,0 +1,601 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Integration tests for Spanner vector store writer."""
+
+import logging
+import os
+import time
+import unittest
+import uuid
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.ml.rag.ingestion.spanner import SpannerVectorWriterConfig
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.types import Content
+from apache_beam.ml.rag.types import Embedding
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# pylint: disable=wrong-import-order, wrong-import-position
+try:
+  from google.cloud import spanner
+except ImportError:
+  spanner = None
+
+try:
+  from testcontainers.core.container import DockerContainer
+except ImportError:
+  DockerContainer = None
+# pylint: enable=wrong-import-order, wrong-import-position
+
+
+def retry(fn, retries, err_msg, *args, **kwargs):
+  """Retry a function with exponential backoff."""
+  for _ in range(retries):
+    try:
+      return fn(*args, **kwargs)
+    except:  # pylint: disable=bare-except
+      time.sleep(1)
+  logging.error(err_msg)
+  raise RuntimeError(err_msg)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `retry` function has a couple of issues:
   1. The docstring claims "exponential backoff", but the implementation uses a 
fixed `time.sleep(1)`. The docstring should be updated to reflect the actual 
behavior (retrying with a fixed delay).
   2. Using a bare `except:` is discouraged as it can catch unexpected 
exceptions like `SystemExit` or `KeyboardInterrupt`. It's better to catch 
`Exception` to avoid this.
   
   ```suggestion
   def retry(fn, retries, err_msg, *args, **kwargs):
     """Retry a function with a fixed delay."""
     for _ in range(retries):
       try:
         return fn(*args, **kwargs)
       except Exception:
         time.sleep(1)
     logging.error(err_msg)
     raise RuntimeError(err_msg)
   ```



##########
sdks/python/apache_beam/ml/rag/ingestion/spanner.py:
##########
@@ -0,0 +1,660 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Cloud Spanner vector store writer for RAG pipelines.
+
+This module provides a writer for storing embeddings and associated metadata
+in Google Cloud Spanner. It supports flexible schema configuration with the
+ability to flatten metadata fields into dedicated columns.
+
+Example usage:
+
+    Default schema (id, embedding, content, metadata):
+    >>> config = SpannerVectorWriterConfig(
+    ...     project_id="my-project",
+    ...     instance_id="my-instance",
+    ...     database_id="my-db",
+    ...     table_name="embeddings"
+    ... )
+
+    Flattened metadata fields:
+    >>> specs = (
+    ...     SpannerColumnSpecsBuilder()
+    ...     .with_id_spec()
+    ...     .with_embedding_spec()
+    ...     .with_content_spec()
+    ...     .add_metadata_field("source", str)
+    ...     .add_metadata_field("page_number", int, default=0)
+    ...     .with_metadata_spec()
+    ...     .build()
+    ... )
+    >>> config = SpannerVectorWriterConfig(
+    ...     project_id="my-project",
+    ...     instance_id="my-instance",
+    ...     database_id="my-db",
+    ...     table_name="embeddings",
+    ...     column_specs=specs
+    ... )
+
+Spanner schema example:
+
+    CREATE TABLE embeddings (
+        id STRING(1024) NOT NULL,
+        embedding ARRAY<FLOAT32>(vector_length=>768),
+        content STRING(MAX),
+        source STRING(MAX),
+        page_number INT64,
+        metadata JSON
+    ) PRIMARY KEY (id)
+"""
+
+import functools
+import json
+from dataclasses import dataclass
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Literal
+from typing import NamedTuple
+from typing import Optional
+from typing import Type
+
+import apache_beam as beam
+from apache_beam.coders import registry
+from apache_beam.coders.row_coder import RowCoder
+from apache_beam.io.gcp import spanner
+from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
+from apache_beam.ml.rag.types import Chunk
+
+
+@dataclass
+class SpannerColumnSpec:
+  """Column specification for Spanner vector writes.
+  
+  Defines how to extract and format values from Chunks for insertion into
+  Spanner table columns. Each spec maps to one column in the target table.
+  
+  Attributes:
+      column_name: Name of the Spanner table column
+      python_type: Python type for the NamedTuple field (required for RowCoder)
+      value_fn: Function to extract value from a Chunk
+  
+  Examples:
+      String column:
+      >>> SpannerColumnSpec(
+      ...     column_name="id",
+      ...     python_type=str,
+      ...     value_fn=lambda chunk: chunk.id
+      ... )
+      
+      Array column with conversion:
+      >>> SpannerColumnSpec(
+      ...     column_name="embedding",
+      ...     python_type=List[float],
+      ...     value_fn=lambda chunk: chunk.embedding.dense_embedding
+      ... )
+  """
+  column_name: str
+  python_type: Type
+  value_fn: Callable[[Chunk], Any]
+
+
+def _extract_and_convert(extract_fn, convert_fn, chunk):
+  if convert_fn:
+    return convert_fn(extract_fn(chunk))
+  return extract_fn(chunk)
+
+
+class SpannerColumnSpecsBuilder:
+  """Builder for creating Spanner column specifications.
+  
+  Provides a fluent API for defining table schemas and how to populate them
+  from Chunk objects. Supports standard Chunk fields (id, embedding, content,
+  metadata) and flattening metadata fields into dedicated columns.
+  
+  Example:
+      >>> specs = (
+      ...     SpannerColumnSpecsBuilder()
+      ...     .with_id_spec()
+      ...     .with_embedding_spec()
+      ...     .with_content_spec()
+      ...     .add_metadata_field("source", str)
+      ...     .with_metadata_spec()
+      ...     .build()
+      ... )
+  """
+  def __init__(self):
+    self._specs: List[SpannerColumnSpec] = []
+
+  @staticmethod
+  def with_defaults() -> 'SpannerColumnSpecsBuilder':
+    """Create builder with default schema.
+    
+    Default schema includes:
+    - id (STRING): Chunk ID
+    - embedding (ARRAY<FLOAT32>): Dense embedding vector
+    - content (STRING): Chunk content text
+    - metadata (JSON): Full metadata as JSON
+    
+    Returns:
+        Builder with default column specifications
+    """
+    return (
+        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
+        with_content_spec().with_metadata_spec())
+
+  def with_id_spec(
+      self,
+      column_name: str = "id",
+      python_type: Type = str,
+      extract_fn: Optional[Callable[[Chunk], Any]] = lambda chunk: chunk.id,
+      convert_fn: Optional[Callable[[Any], Any]] = None
+  ) -> 'SpannerColumnSpecsBuilder':
+    """Add ID column specification.
+    
+    Args:
+        column_name: Column name (default: "id")
+        python_type: Python type (default: str)
+        extract_fn: Value extractor (default: lambda chunk: chunk.id)
+        convert_fn: Optional converter (e.g., to cast to int)
+    
+    Returns:
+        Self for method chaining
+    
+    Examples:
+        Default string ID:
+        >>> builder.with_id_spec()
+        
+        Integer ID with conversion:
+        >>> builder.with_id_spec(
+        ...     python_type=int,
+        ...     convert_fn=lambda id: int(id.split('_')[1])
+        ... )
+    """
+
+    self._specs.append(
+        SpannerColumnSpec(
+            column_name=column_name,
+            python_type=python_type,
+            value_fn=functools.partial(
+                _extract_and_convert, extract_fn, convert_fn)))
+    return self
+
+  def with_embedding_spec(
+      self,
+      column_name: str = "embedding",
+      extract_fn: Optional[Callable[[Chunk], List[float]]] = None,
+      convert_fn: Optional[Callable[[List[float]], List[float]]] = None
+  ) -> 'SpannerColumnSpecsBuilder':
+    """Add embedding array column (ARRAY<FLOAT32> or ARRAY<FLOAT64>).
+    
+    Args:
+        column_name: Column name (default: "embedding")
+        extract_fn: Value extractor (default: chunk.embedding.dense_embedding)
+        convert_fn: Optional converter (e.g., normalize, quantize)
+    
+    Returns:
+        Self for method chaining
+    
+    Examples:
+        Default embedding:
+        >>> builder.with_embedding_spec()
+        
+        Normalized embedding:
+        >>> def normalize(vec):
+        ...     norm = (sum(x**2 for x in vec) ** 0.5) or 1.0
+        ...     return [x/norm for x in vec]
+        >>> builder.with_embedding_spec(convert_fn=normalize)
+        
+        Rounded precision:
+        >>> builder.with_embedding_spec(
+        ...     convert_fn=lambda vec: [round(x, 4) for x in vec]
+        ... )
+    """
+    def default_fn(chunk: Chunk) -> List[float]:
+      if chunk.embedding is None or chunk.embedding.dense_embedding is None:
+        raise ValueError(f'Chunk must contain embedding: {chunk}')
+      return chunk.embedding.dense_embedding
+
+    extract_fn = extract_fn or default_fn
+
+    self._specs.append(
+        SpannerColumnSpec(
+            column_name=column_name,
+            python_type=List[float],
+            value_fn=functools.partial(
+                _extract_and_convert, extract_fn, convert_fn)))
+    return self
+
+  def with_content_spec(
+      self,
+      column_name: str = "content",
+      python_type: Type = str,
+      extract_fn: Optional[Callable[[Chunk], Any]] = None,
+      convert_fn: Optional[Callable[[Any], Any]] = None
+  ) -> 'SpannerColumnSpecsBuilder':
+    """Add content column.
+    
+    Args:
+        column_name: Column name (default: "content")
+        python_type: Python type (default: str)
+        extract_fn: Value extractor (default: chunk.content.text)
+        convert_fn: Optional converter
+    
+    Returns:
+        Self for method chaining
+    
+    Examples:
+        Default text content:
+        >>> builder.with_content_spec()
+        
+        Content length as integer:
+        >>> builder.with_content_spec(
+        ...     column_name="content_length",
+        ...     python_type=int,
+        ...     convert_fn=lambda text: len(text.split())
+        ... )
+        
+        Truncated content:
+        >>> builder.with_content_spec(
+        ...     convert_fn=lambda text: text[:1000]
+        ... )
+    """
+    def default_fn(chunk: Chunk) -> str:
+      if chunk.content.text is None:
+        raise ValueError(f'Chunk must contain content: {chunk}')
+      return chunk.content.text
+
+    extract_fn = extract_fn or default_fn
+
+    self._specs.append(
+        SpannerColumnSpec(
+            column_name=column_name,
+            python_type=python_type,
+            value_fn=functools.partial(
+                _extract_and_convert, extract_fn, convert_fn)))
+    return self
+
+  def with_metadata_spec(
+      self,
+      column_name: str = "metadata",
+      value_fn: Optional[Callable[[Chunk], Any]] = None
+  ) -> 'SpannerColumnSpecsBuilder':
+    """Add metadata JSON column.
+    
+    Stores the full metadata dictionary as a JSON string in Spanner.
+    
+    Args:
+        column_name: Column name (default: "metadata")
+        value_fn: Value extractor (default: lambda chunk: chunk.metadata)
+    
+    Returns:
+        Self for method chaining
+    
+    Note:
+        Metadata is automatically converted to JSON string using json.dumps()
+    """
+    value_fn = value_fn or (lambda chunk: json.dumps(chunk.metadata))
+    self._specs.append(
+        SpannerColumnSpec(
+            column_name=column_name, python_type=str, value_fn=value_fn))
+    return self
+
+  def add_metadata_field(
+      self,
+      field: str,
+      python_type: Type,
+      column_name: Optional[str] = None,
+      convert_fn: Optional[Callable[[Any], Any]] = None,
+      default: Any = None) -> 'SpannerColumnSpecsBuilder':
+    """Flatten a metadata field into its own column.
+    
+    Extracts a specific field from chunk.metadata and stores it in a
+    dedicated table column.
+    
+    Args:
+        field: Key in chunk.metadata to extract
+        python_type: Python type (must be explicitly specified)
+        column_name: Column name (default: same as field)
+        convert_fn: Optional converter for type casting/transformation
+        default: Default value if field is missing from metadata
+    
+    Returns:
+        Self for method chaining
+    
+    Examples:
+        String field:
+        >>> builder.add_metadata_field("source", str)
+        
+        Integer with default:
+        >>> builder.add_metadata_field(
+        ...     "page_number",
+        ...     int,
+        ...     default=0
+        ... )
+        
+        Float with conversion:
+        >>> builder.add_metadata_field(
+        ...     "confidence",
+        ...     float,
+        ...     convert_fn=lambda x: round(float(x), 2),
+        ...     default=0.0
+        ... )
+        
+        List of strings:
+        >>> builder.add_metadata_field(
+        ...     "tags",
+        ...     List[str],
+        ...     default=[]
+        ... )
+        
+        Timestamp with conversion:
+        >>> builder.add_metadata_field(
+        ...     "created_at",
+        ...     str,
+        ...     convert_fn=lambda ts: ts.isoformat()
+        ... )
+    """
+    name = column_name or field
+
+    def value_fn(chunk: Chunk) -> Any:
+      return chunk.metadata.get(field, default)
+
+    self._specs.append(
+        SpannerColumnSpec(
+            column_name=name,
+            python_type=python_type,
+            value_fn=functools.partial(
+                _extract_and_convert, value_fn, convert_fn)))
+    return self
+
+  def add_column(
+      self,
+      column_name: str,
+      python_type: Type,
+      value_fn: Callable[[Chunk], Any]) -> 'SpannerColumnSpecsBuilder':
+    """Add a custom column with full control.
+    
+    Args:
+        column_name: Column name
+        python_type: Python type (required)
+        value_fn: Value extraction function
+    
+    Returns:
+        Self for method chaining
+    
+    Examples:
+        Boolean flag:
+        >>> builder.add_column(
+        ...     column_name="has_code",
+        ...     python_type=bool,
+        ...     value_fn=lambda chunk: "```" in chunk.content.text
+        ... )
+        
+        Computed value:
+        >>> builder.add_column(
+        ...     column_name="word_count",
+        ...     python_type=int,
+        ...     value_fn=lambda chunk: len(chunk.content.text.split())
+        ... )
+    """
+    self._specs.append(
+        SpannerColumnSpec(
+            column_name=column_name, python_type=python_type,
+            value_fn=value_fn))
+    return self
+
+  def build(self) -> List[SpannerColumnSpec]:
+    """Build the final list of column specifications.
+    
+    Returns:
+        List of SpannerColumnSpec objects
+    """
+    return self._specs.copy()
+
+
+class _SpannerSchemaBuilder:
+  """Internal: Builds NamedTuple schema and registers RowCoder.
+  
+  Creates a NamedTuple type from column specifications and registers it
+  with Beam's RowCoder for serialization.
+  """
+  def __init__(self, table_name: str, column_specs: List[SpannerColumnSpec]):
+    """Initialize schema builder.
+    
+    Args:
+        table_name: Table name (used in NamedTuple type name)
+        column_specs: List of column specifications
+    
+    Raises:
+        ValueError: If duplicate column names are found
+    """
+    self.table_name = table_name
+    self.column_specs = column_specs
+
+    # Validate no duplicates
+    names = [col.column_name for col in column_specs]
+    duplicates = set(name for name in names if names.count(name) > 1)
+    if duplicates:
+      raise ValueError(f"Duplicate column names: {duplicates}")

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The current implementation for finding duplicate column names has a time 
complexity of O(n^2), which can be inefficient if there are many columns. A 
more performant approach with O(n) complexity can be achieved by using a set to 
track seen column names. This also allows for a deterministic error message by 
sorting the duplicates.
   
   ```suggestion
       seen = set()
       duplicates = set()
       for name in names:
         if name in seen:
           duplicates.add(name)
         seen.add(name)
       if duplicates:
         raise ValueError(f"Duplicate column names: {sorted(list(duplicates))}")
   ```



##########
sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py:
##########
@@ -0,0 +1,601 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Integration tests for Spanner vector store writer."""
+
+import logging
+import os
+import time
+import unittest
+import uuid
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.ml.rag.ingestion.spanner import SpannerVectorWriterConfig
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.types import Content
+from apache_beam.ml.rag.types import Embedding
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# pylint: disable=wrong-import-order, wrong-import-position
+try:
+  from google.cloud import spanner
+except ImportError:
+  spanner = None
+
+try:
+  from testcontainers.core.container import DockerContainer
+except ImportError:
+  DockerContainer = None
+# pylint: enable=wrong-import-order, wrong-import-position
+
+
+def retry(fn, retries, err_msg, *args, **kwargs):
+  """Retry a function with exponential backoff."""
+  for _ in range(retries):
+    try:
+      return fn(*args, **kwargs)
+    except:  # pylint: disable=bare-except
+      time.sleep(1)
+  logging.error(err_msg)
+  raise RuntimeError(err_msg)
+
+
+class SpannerEmulatorHelper:
+  """Helper for managing Spanner emulator lifecycle."""
+  def __init__(self, project_id: str, instance_id: str, table_name: str):
+    self.project_id = project_id
+    self.instance_id = instance_id
+    self.table_name = table_name
+    self.host = None
+
+    # Start emulator
+    self.emulator = DockerContainer(
+        'gcr.io/cloud-spanner-emulator/emulator:latest').with_exposed_ports(
+            9010, 9020)
+    retry(self.emulator.start, 3, 'Could not start spanner emulator.')
+    time.sleep(3)
+
+    self.host = f'{self.emulator.get_container_host_ip()}:' \
+                f'{self.emulator.get_exposed_port(9010)}'
+    os.environ['SPANNER_EMULATOR_HOST'] = self.host
+
+    # Create client and instance
+    self.client = spanner.Client(project_id)
+    self.instance = self.client.instance(instance_id)
+    self.create_instance()
+
+  def create_instance(self):
+    """Create Spanner instance in emulator."""
+    self.instance.create().result(120)
+
+  def create_database(self, database_id: str):
+    """Create database with default vector table schema."""
+    database = self.instance.database(
+        database_id,
+        ddl_statements=[
+            f'''
+          CREATE TABLE {self.table_name} (
+              id STRING(1024) NOT NULL,
+              embedding ARRAY<FLOAT32>(vector_length=>3),
+              content STRING(MAX),
+              metadata JSON
+          ) PRIMARY KEY (id)'''
+        ])
+    database.create().result(120)
+
+  def read_data(self, database_id: str):
+    """Read all data from the table."""
+    database = self.instance.database(database_id)
+    with database.snapshot() as snapshot:
+      results = snapshot.execute_sql(
+          f'SELECT * FROM {self.table_name} ORDER BY id')
+      return list(results) if results else []
+
+  def drop_database(self, database_id: str):
+    """Drop the database."""
+    database = self.instance.database(database_id)
+    database.drop()
+
+  def shutdown(self):
+    """Stop the emulator."""
+    if self.emulator:
+      try:
+        self.emulator.stop()
+      except:  # pylint: disable=bare-except
+        logging.error('Could not stop Spanner emulator.')
+
+  def get_emulator_host(self) -> str:
+    """Get the emulator host URL."""
+    return f'http://{self.host}'
+
+
+@pytest.mark.uses_gcp_java_expansion_service
+@unittest.skipUnless(
+    os.environ.get('EXPANSION_JARS'),
+    "EXPANSION_JARS environment var is not provided, "
+    "indicating that jars have not been built")
+@unittest.skipIf(spanner is None, 'GCP dependencies are not installed.')
+@unittest.skipIf(
+    DockerContainer is None, 'testcontainers package is not installed.')
+class SpannerVectorWriterTest(unittest.TestCase):
+  """Integration tests for Spanner vector writer."""
+  @classmethod
+  def setUpClass(cls):
+    """Set up Spanner emulator for all tests."""
+    cls.project_id = 'test-project'
+    cls.instance_id = 'test-instance'
+    cls.table_name = 'embeddings'
+
+    cls.spanner_helper = SpannerEmulatorHelper(
+        cls.project_id, cls.instance_id, cls.table_name)
+
+  @classmethod
+  def tearDownClass(cls):
+    """Tear down Spanner emulator."""
+    cls.spanner_helper.shutdown()
+
+  def setUp(self):
+    """Create a unique database for each test."""
+    self.database_id = f'test_db_{uuid.uuid4().hex}'[:30]
+    self.spanner_helper.create_database(self.database_id)
+
+  def tearDown(self):
+    """Drop the test database."""
+    self.spanner_helper.drop_database(self.database_id)
+
+  def test_write_default_schema(self):
+    """Test writing with default schema (id, embedding, content, metadata)."""
+    # Create test chunks
+    chunks = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+            content=Content(text='First document'),
+            metadata={
+                'source': 'test', 'page': 1
+            }),
+        Chunk(
+            id='doc2',
+            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+            content=Content(text='Second document'),
+            metadata={
+                'source': 'test', 'page': 2
+            }),
+    ]
+
+    # Create config with default schema
+    config = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    # Write chunks
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+    # Verify data was written
+    results = self.spanner_helper.read_data(self.database_id)
+    self.assertEqual(len(results), 2)
+
+    # Check first row
+    row1 = results[0]
+    self.assertEqual(row1[0], 'doc1')  # id
+    self.assertEqual(list(row1[1]), [1.0, 2.0, 3.0])  # embedding
+    self.assertEqual(row1[2], 'First document')  # content
+    # metadata is JSON
+    metadata1 = row1[3]
+    self.assertEqual(metadata1['source'], 'test')
+    self.assertEqual(metadata1['page'], 1)
+
+    # Check second row
+    row2 = results[1]
+    self.assertEqual(row2[0], 'doc2')
+    self.assertEqual(list(row2[1]), [4.0, 5.0, 6.0])
+    self.assertEqual(row2[2], 'Second document')
+
+  def test_write_flattened_metadata(self):
+    """Test writing with flattened metadata fields."""
+    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Imports should generally be at the top of the file, as per PEP 8 guidelines. 
Please move this import to the top of the file with the other imports. This 
applies to similar local imports in other test methods in this file.



##########
sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py:
##########
@@ -0,0 +1,601 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Integration tests for Spanner vector store writer."""
+
+import logging
+import os
+import time
+import unittest
+import uuid
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.ml.rag.ingestion.spanner import SpannerVectorWriterConfig
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.types import Content
+from apache_beam.ml.rag.types import Embedding
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# pylint: disable=wrong-import-order, wrong-import-position
+try:
+  from google.cloud import spanner
+except ImportError:
+  spanner = None
+
+try:
+  from testcontainers.core.container import DockerContainer
+except ImportError:
+  DockerContainer = None
+# pylint: enable=wrong-import-order, wrong-import-position
+
+
+def retry(fn, retries, err_msg, *args, **kwargs):
+  """Retry a function with exponential backoff."""
+  for _ in range(retries):
+    try:
+      return fn(*args, **kwargs)
+    except:  # pylint: disable=bare-except
+      time.sleep(1)
+  logging.error(err_msg)
+  raise RuntimeError(err_msg)
+
+
+class SpannerEmulatorHelper:
+  """Helper for managing Spanner emulator lifecycle."""
+  def __init__(self, project_id: str, instance_id: str, table_name: str):
+    self.project_id = project_id
+    self.instance_id = instance_id
+    self.table_name = table_name
+    self.host = None
+
+    # Start emulator
+    self.emulator = DockerContainer(
+        'gcr.io/cloud-spanner-emulator/emulator:latest').with_exposed_ports(
+            9010, 9020)
+    retry(self.emulator.start, 3, 'Could not start spanner emulator.')
+    time.sleep(3)
+
+    self.host = f'{self.emulator.get_container_host_ip()}:' \
+                f'{self.emulator.get_exposed_port(9010)}'
+    os.environ['SPANNER_EMULATOR_HOST'] = self.host
+
+    # Create client and instance
+    self.client = spanner.Client(project_id)
+    self.instance = self.client.instance(instance_id)
+    self.create_instance()
+
+  def create_instance(self):
+    """Create Spanner instance in emulator."""
+    self.instance.create().result(120)
+
+  def create_database(self, database_id: str):
+    """Create database with default vector table schema."""
+    database = self.instance.database(
+        database_id,
+        ddl_statements=[
+            f'''
+          CREATE TABLE {self.table_name} (
+              id STRING(1024) NOT NULL,
+              embedding ARRAY<FLOAT32>(vector_length=>3),
+              content STRING(MAX),
+              metadata JSON
+          ) PRIMARY KEY (id)'''
+        ])
+    database.create().result(120)
+
+  def read_data(self, database_id: str):
+    """Read all data from the table."""
+    database = self.instance.database(database_id)
+    with database.snapshot() as snapshot:
+      results = snapshot.execute_sql(
+          f'SELECT * FROM {self.table_name} ORDER BY id')
+      return list(results) if results else []
+
+  def drop_database(self, database_id: str):
+    """Drop the database."""
+    database = self.instance.database(database_id)
+    database.drop()
+
+  def shutdown(self):
+    """Stop the emulator."""
+    if self.emulator:
+      try:
+        self.emulator.stop()
+      except:  # pylint: disable=bare-except
+        logging.error('Could not stop Spanner emulator.')

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Using a bare `except:` is discouraged as it can catch unexpected exceptions 
like `SystemExit` or `KeyboardInterrupt`. It's better to catch `Exception` to 
avoid unintended side effects.
   
   ```suggestion
         try:
           self.emulator.stop()
         except Exception:
           logging.error('Could not stop Spanner emulator.')
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to