This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 83ebe731133 [1/3] sdks/python: refactor Milvus-related utilities as preparation step for Milvus Sink I/O integration (#35708)
83ebe731133 is described below

commit 83ebe73113391c1650680d2665a15849b536e776
Author: Mohamed Awnallah <[email protected]>
AuthorDate: Wed Nov 12 20:50:46 2025 +0200

    [1/3] sdks/python: refactor Milvus-related utilities as preparation step for Milvus Sink I/O integration (#35708)
    
    * sdks/python: replace the deprecated testcontainer max tries
    
    * sdks/python: handle transient testcontainer startup/teardown errors
    
    * sdks/python: bump `testcontainers` py pkg version
    
    * sdks/python: integrate milvus sink I/O
    
    * sdks/python: fix linting issues
    
    * sdks/python: add missing apache beam license header for `test_utils.py`
    
    * notebooks/beam-ml: use new refactored code in milvus enrichment handler
    
    * CHANGES.md: update release notes
    
    * sdks/python: mark milvus itests with `require_docker_in_docker` marker
    
    * sdks/python: override milvus db version with the default
    
    * sdks/python: add missing import in rag utils
    
    * sdks/python: fix linting issue
    
    * rag/ingestion/milvus_search_itest.py: ensure flushing in-memory data before querying
    
    * sdks/python: fix linting issues
    
    * sdks/python: fix formatting issues
    
    * sdks/python: fix arising linting issue
    
    * rag: reuse `retry_with_backoff` for one-time setup operations
    
    * sdks/python: fix linting issues
    
    * sdks/python: fix py docs CI issue
    
    * sdks/python: fix linting issues
    
    * sdks/python: fix linting issues
    
    * sdks/python: isolate milvus sink integration to be in follow-up PR
    
    * CHANGES.md: remove milvus from release notes in the refactoring PR
    
    * sdks/python: remove `with_sparse_embedding_spec` column specs builder
    
    In this commit, we remove that builder method; it will be reintroduced
    and used in the next Milvus sink integration PR
    
    * sdks/python: fix linting issues
    
    * Revert "notebooks/beam-ml: use new refactored code in milvus enrichment 
handler"
    
    This reverts commit 461c8fee9d1d4b63b63558d188f88f3e79856309.
    
    * sdks/python: fix linting issues
    
    * sdks/python: fix linting issues
    
    * sdks/python: fix linting issues
    
    * sdks/python: fix linting issues
    
    * CI: fix import errors in CI
    
    * sdks/python: fix linting issues
    
    * sdks/python: fix linting issues
    
    * sdks/python: fix linting issues
    
    * sdks/python: fix linting issues
---
 .../transforms/elementwise/enrichment_test.py      |  59 +--
 .../apache_beam/ml/rag/enrichment/milvus_search.py | 133 ++-----
 .../ml/rag/enrichment/milvus_search_it_test.py     | 429 ++-------------------
 .../ml/rag/ingestion/postgres_common.py            |  56 +--
 sdks/python/apache_beam/ml/rag/test_utils.py       | 414 ++++++++++++++++++++
 sdks/python/apache_beam/ml/rag/utils.py            | 224 +++++++++++
 6 files changed, 759 insertions(+), 556 deletions(-)

diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py
index c8e988a52c5..ed2b0c131e0 100644
--- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py
+++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py
@@ -52,13 +52,10 @@ try:
       ConnectionConfig,
       CloudSQLConnectionConfig,
       ExternalSQLDBConnectionConfig)
-  from apache_beam.ml.rag.enrichment.milvus_search import (
-      MilvusConnectionParameters)
-  from apache_beam.ml.rag.enrichment.milvus_search_it_test import (
-      MilvusEnrichmentTestHelper,
-      MilvusDBContainerInfo,
-      parse_chunk_strings,
-      assert_chunks_equivalent)
+  from apache_beam.ml.rag.enrichment.milvus_search import MilvusConnectionParameters
+  from apache_beam.ml.rag.test_utils import MilvusTestHelpers
+  from apache_beam.ml.rag.test_utils import VectorDBContainerInfo
+  from apache_beam.ml.rag.utils import parse_chunk_strings
   from apache_beam.io.requestresponse import RequestResponseIO
 except ImportError as e:
   raise unittest.SkipTest(f'Examples dependencies are not installed: {str(e)}')
@@ -69,6 +67,11 @@ class TestContainerStartupError(Exception):
   pass
 
 
+class TestContainerTeardownError(Exception):
+  """Raised when any test container fails to teardown."""
+  pass
+
+
 def validate_enrichment_with_bigtable():
   expected = '''[START enrichment_with_bigtable]
 Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'})
@@ -186,7 +189,7 @@ class EnrichmentTest(unittest.TestCase):
         output = mock_stdout.getvalue().splitlines()
         expected = validate_enrichment_with_external_pg()
         self.assertEqual(output, expected)
-    except TestContainerStartupError as e:
+    except (TestContainerStartupError, TestContainerTeardownError) as e:
       raise unittest.SkipTest(str(e))
     except Exception as e:
       self.fail(f"Test failed with unexpected error: {e}")
@@ -199,7 +202,7 @@ class EnrichmentTest(unittest.TestCase):
         output = mock_stdout.getvalue().splitlines()
         expected = validate_enrichment_with_external_mysql()
         self.assertEqual(output, expected)
-    except TestContainerStartupError as e:
+    except (TestContainerStartupError, TestContainerTeardownError) as e:
       raise unittest.SkipTest(str(e))
     except Exception as e:
       self.fail(f"Test failed with unexpected error: {e}")
@@ -212,7 +215,7 @@ class EnrichmentTest(unittest.TestCase):
         output = mock_stdout.getvalue().splitlines()
         expected = validate_enrichment_with_external_sqlserver()
         self.assertEqual(output, expected)
-    except TestContainerStartupError as e:
+    except (TestContainerStartupError, TestContainerTeardownError) as e:
       raise unittest.SkipTest(str(e))
     except Exception as e:
       self.fail(f"Test failed with unexpected error: {e}")
@@ -226,8 +229,8 @@ class EnrichmentTest(unittest.TestCase):
         self.maxDiff = None
         output = parse_chunk_strings(output)
         expected = parse_chunk_strings(expected)
-        assert_chunks_equivalent(output, expected)
-    except TestContainerStartupError as e:
+        MilvusTestHelpers.assert_chunks_equivalent(output, expected)
+    except (TestContainerStartupError, TestContainerTeardownError) as e:
       raise unittest.SkipTest(str(e))
     except Exception as e:
       self.fail(f"Test failed with unexpected error: {e}")
@@ -257,7 +260,7 @@ class EnrichmentTestHelpers:
   @staticmethod
   @contextmanager
   def milvus_test_context():
-    db: Optional[MilvusDBContainerInfo] = None
+    db: Optional[VectorDBContainerInfo] = None
     try:
       db = EnrichmentTestHelpers.pre_milvus_enrichment()
       yield
@@ -370,23 +373,21 @@ class EnrichmentTestHelpers:
       os.environ.pop('GOOGLE_CLOUD_SQL_DB_TABLE_ID', None)
 
   @staticmethod
-  def pre_milvus_enrichment() -> MilvusDBContainerInfo:
+  def pre_milvus_enrichment() -> VectorDBContainerInfo:
     try:
-      db = MilvusEnrichmentTestHelper.start_db_container()
+      db = MilvusTestHelpers.start_db_container()
+      connection_params = MilvusConnectionParameters(
+          uri=db.uri,
+          user=db.user,
+          password=db.password,
+          db_id=db.id,
+          token=db.token)
+      collection_name = MilvusTestHelpers.initialize_db_with_data(
+          connection_params)
     except Exception as e:
       raise TestContainerStartupError(
           f"Milvus container failed to start: {str(e)}")
 
-    connection_params = MilvusConnectionParameters(
-        uri=db.uri,
-        user=db.user,
-        password=db.password,
-        db_id=db.id,
-        token=db.token)
-
-    collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data(
-        connection_params)
-
    # Setup environment variables for db and collection configuration. This will
    # be used downstream by the milvus enrichment handler.
     os.environ['MILVUS_VECTOR_DB_URI'] = db.uri
@@ -399,8 +400,13 @@ class EnrichmentTestHelpers:
     return db
 
   @staticmethod
-  def post_milvus_enrichment(db: MilvusDBContainerInfo):
-    MilvusEnrichmentTestHelper.stop_db_container(db)
+  def post_milvus_enrichment(db: VectorDBContainerInfo):
+    try:
+      MilvusTestHelpers.stop_db_container(db)
+    except Exception as e:
+      raise TestContainerTeardownError(
+          f"Milvus container failed to tear down: {str(e)}")
+
     os.environ.pop('MILVUS_VECTOR_DB_URI', None)
     os.environ.pop('MILVUS_VECTOR_DB_USER', None)
     os.environ.pop('MILVUS_VECTOR_DB_PASSWORD', None)
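The helpers above pass connection details to the snippet under test through
environment variables rather than function arguments. As a rough sketch of the
consuming side, assuming only the variable names visible in this diff (the
handler construction here is illustrative, not the snippet's actual code):

    import os

    from apache_beam.ml.rag.enrichment.milvus_search import (
        MilvusConnectionParameters)

    # Rebuild connection parameters from the environment populated by
    # EnrichmentTestHelpers.pre_milvus_enrichment().
    connection_params = MilvusConnectionParameters(
        uri=os.environ['MILVUS_VECTOR_DB_URI'],
        user=os.environ.get('MILVUS_VECTOR_DB_USER', ''),
        password=os.environ.get('MILVUS_VECTOR_DB_PASSWORD', ''))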
diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py
index 8f631746748..41355e8c10a 100644
--- a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py
+++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py
@@ -32,9 +32,14 @@ from pymilvus import Hit
 from pymilvus import Hits
 from pymilvus import MilvusClient
 from pymilvus import SearchResult
+from pymilvus.exceptions import MilvusException
 
 from apache_beam.ml.rag.types import Chunk
 from apache_beam.ml.rag.types import Embedding
+from apache_beam.ml.rag.utils import MilvusConnectionParameters
+from apache_beam.ml.rag.utils import MilvusHelpers
+from apache_beam.ml.rag.utils import retry_with_backoff
+from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs
 from apache_beam.transforms.enrichment import EnrichmentSourceHandler
 
 
@@ -104,44 +109,6 @@ class MilvusBaseRanker:
     return self.dict().__str__()
 
 
-@dataclass
-class MilvusConnectionParameters:
-  """Parameters for establishing connections to Milvus servers.
-
-  Args:
-    uri: URI endpoint for connecting to Milvus server in the format
-      "http(s)://hostname:port".
-    user: Username for authentication. Required if authentication is enabled and
-      not using token authentication.
-    password: Password for authentication. Required if authentication is enabled
-      and not using token authentication.
-    db_id: Database ID to connect to. Specifies which Milvus database to use.
-      Defaults to 'default'.
-    token: Authentication token as an alternative to username/password.
-    timeout: Connection timeout in seconds. Uses client default if None.
-    max_retries: Maximum number of connection retry attempts. Defaults to 3.
-    retry_delay: Initial delay between retries in seconds. Defaults to 1.0.
-    retry_backoff_factor: Multiplier for retry delay after each attempt. 
-      Defaults to 2.0 (exponential backoff).
-    kwargs: Optional keyword arguments for additional connection parameters.
-      Enables forward compatibility.
-  """
-  uri: str
-  user: str = field(default_factory=str)
-  password: str = field(default_factory=str)
-  db_id: str = "default"
-  token: str = field(default_factory=str)
-  timeout: Optional[float] = None
-  max_retries: int = 3
-  retry_delay: float = 1.0
-  retry_backoff_factor: float = 2.0
-  kwargs: Dict[str, Any] = field(default_factory=dict)
-
-  def __post_init__(self):
-    if not self.uri:
-      raise ValueError("URI must be provided for Milvus connection")
-
-
 @dataclass
 class BaseSearchParameters:
   """Base parameters for both vector and keyword search operations.
@@ -361,7 +328,7 @@ class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]):
       **kwargs):
     """
     Example Usage:
-      connection_paramters = MilvusConnectionParameters(
+      connection_parameters = MilvusConnectionParameters(
         uri="http://localhost:19530";)
       search_parameters = MilvusSearchParameters(
         collection_name="my_collection",
@@ -369,7 +336,7 @@ class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]):
       collection_load_parameters = MilvusCollectionLoadParameters(
         load_fields=["embedding", "metadata"]),
       milvus_handler = MilvusSearchEnrichmentHandler(
-        connection_paramters,
+        connection_parameters,
         search_parameters,
         collection_load_parameters=collection_load_parameters,
         min_batch_size=10,
@@ -407,52 +374,43 @@ class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]):
         'min_batch_size': min_batch_size, 'max_batch_size': max_batch_size
     }
     self.kwargs = kwargs
+    self._client = None
     self.join_fn = join_fn
     self.use_custom_types = True
 
   def __enter__(self):
-    import logging
-    import time
-
-    from pymilvus.exceptions import MilvusException
-
-    connection_params = unpack_dataclass_with_kwargs(
-        self._connection_parameters)
-    collection_load_params = unpack_dataclass_with_kwargs(
-        self._collection_load_parameters)
-
-    # Extract retry parameters from connection_params
-    max_retries = connection_params.pop('max_retries', 3)
-    retry_delay = connection_params.pop('retry_delay', 1.0)
-    retry_backoff_factor = connection_params.pop('retry_backoff_factor', 2.0)
-
-    # Retry logic for MilvusClient connection
-    last_exception = None
-    for attempt in range(max_retries + 1):
-      try:
-        self._client = MilvusClient(**connection_params)
-        self._client.load_collection(
+    """Enters the context manager and establishes Milvus connection.
+
+    Returns:
+      Self, enabling use in 'with' statements.
+    """
+    if not self._client:
+      connection_params = unpack_dataclass_with_kwargs(
+          self._connection_parameters)
+      collection_load_params = unpack_dataclass_with_kwargs(
+          self._collection_load_parameters)
+
+      # Extract retry parameters from connection_params.
+      max_retries = connection_params.pop('max_retries', 3)
+      retry_delay = connection_params.pop('retry_delay', 1.0)
+      retry_backoff_factor = connection_params.pop('retry_backoff_factor', 2.0)
+
+      def connect_and_load():
+        client = MilvusClient(**connection_params)
+        client.load_collection(
             collection_name=self.collection_name,
             partition_names=self.partition_names,
             **collection_load_params)
-        logging.info(
-            "Successfully connected to Milvus on attempt %d", attempt + 1)
-        return
-      except MilvusException as e:
-        last_exception = e
-        if attempt < max_retries:
-          delay = retry_delay * (retry_backoff_factor**attempt)
-          logging.warning(
-              "Milvus connection attempt %d failed: %s. "
-              "Retrying in %.2f seconds...",
-              attempt + 1,
-              e,
-              delay)
-          time.sleep(delay)
-        else:
-          logging.error(
-              "Failed to connect to Milvus after %d attempts", max_retries + 1)
-          raise last_exception
+        return client
+
+      self._client = retry_with_backoff(
+          connect_and_load,
+          max_retries=max_retries,
+          retry_delay=retry_delay,
+          retry_backoff_factor=retry_backoff_factor,
+          operation_name="Milvus connection and collection load",
+          exception_types=(MilvusException, ))
+    return self
 
   def __call__(self, request: Union[Chunk, List[Chunk]], *args,
                **kwargs) -> List[Tuple[Chunk, Dict[str, Any]]]:
@@ -535,10 +493,7 @@ class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]):
       raise ValueError(
           f"Chunk {chunk.id} missing both text content and sparse embedding "
           "required for keyword search")
-
-    sparse_embedding = self.convert_sparse_embedding_to_milvus_format(
-        chunk.sparse_embedding)
-
+    sparse_embedding = MilvusHelpers.sparse_embedding(chunk.sparse_embedding)
     return chunk.content.text or sparse_embedding
 
   def _get_call_response(
@@ -628,15 +583,3 @@ class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]):
 def join_fn(left: Embedding, right: Dict[str, Any]) -> Embedding:
   left.metadata['enrichment_data'] = right
   return left
-
-
-def unpack_dataclass_with_kwargs(dataclass_instance):
-  # Create a copy of the dataclass's __dict__.
-  params_dict: dict = dataclass_instance.__dict__.copy()
-
-  # Extract the nested kwargs dictionary.
-  nested_kwargs = params_dict.pop('kwargs', {})
-
-  # Merge the dictionaries, with nested_kwargs taking precedence
-  # in case of duplicate keys.
-  return {**params_dict, **nested_kwargs}
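The inline retry loop deleted from __enter__ above now lives in
apache_beam.ml.rag.utils.retry_with_backoff, whose diff is not shown in this
excerpt. A minimal sketch of a helper matching the new call site and the
semantics of the removed loop (the shipped implementation may differ):

    import logging
    import time


    def retry_with_backoff(
        fn,
        max_retries=3,
        retry_delay=1.0,
        retry_backoff_factor=2.0,
        operation_name="operation",
        exception_types=(Exception, )):
      """Calls fn(), retrying on the given exceptions with exponential backoff."""
      for attempt in range(max_retries + 1):
        try:
          return fn()
        except exception_types as e:
          if attempt >= max_retries:
            logging.error(
                "%s failed after %d attempts", operation_name, max_retries + 1)
            raise
          # Same schedule as the removed loop: the delay grows by
          # retry_backoff_factor after each failed attempt.
          delay = retry_delay * (retry_backoff_factor**attempt)
          logging.warning(
              "%s attempt %d failed: %s. Retrying in %.2f seconds...",
              operation_name, attempt + 1, e, delay)
          time.sleep(delay)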
diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py
index b3a0dcd5572..34cb3f9050f 100644
--- a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py
+++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py
@@ -15,25 +15,13 @@
 # limitations under the License.
 #
 
-import contextlib
-import logging
-import os
 import platform
-import re
-import socket
-import tempfile
 import unittest
-from collections import defaultdict
 from dataclasses import dataclass
 from dataclasses import field
-from typing import Callable
 from typing import Dict
-from typing import List
-from typing import Optional
-from typing import cast
 
 import pytest
-import yaml
 
 import apache_beam as beam
 from apache_beam.ml.rag.types import Chunk
@@ -44,18 +32,12 @@ from apache_beam.testing.util import assert_that
 
 # pylint: disable=ungrouped-imports
 try:
-  from pymilvus import CollectionSchema
   from pymilvus import DataType
   from pymilvus import FieldSchema
   from pymilvus import Function
   from pymilvus import FunctionType
-  from pymilvus import MilvusClient
   from pymilvus import RRFRanker
   from pymilvus.milvus_client import IndexParams
-  from testcontainers.core.config import MAX_TRIES as TC_MAX_TRIES
-  from testcontainers.core.config import testcontainers_config
-  from testcontainers.core.generic import DbContainer
-  from testcontainers.milvus import MilvusContainer
 
  from apache_beam.ml.rag.enrichment.milvus_search import HybridSearchParameters
   from apache_beam.ml.rag.enrichment.milvus_search import KeywordSearchMetrics
@@ -66,12 +48,12 @@ try:
  from apache_beam.ml.rag.enrichment.milvus_search import MilvusSearchParameters
   from apache_beam.ml.rag.enrichment.milvus_search import VectorSearchMetrics
  from apache_beam.ml.rag.enrichment.milvus_search import VectorSearchParameters
+  from apache_beam.ml.rag.test_utils import MilvusTestHelpers
+  from apache_beam.ml.rag.test_utils import VectorDBContainerInfo
   from apache_beam.transforms.enrichment import Enrichment
 except ImportError as e:
   raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}')
 
-_LOGGER = logging.getLogger(__name__)
-
 
 def _construct_index_params():
   index_params = IndexParams()
@@ -243,244 +225,6 @@ MILVUS_IT_CONFIG = {
 }
 
 
-@dataclass
-class MilvusDBContainerInfo:
-  container: DbContainer
-  host: str
-  port: int
-  user: Optional[str] = ""
-  password: Optional[str] = ""
-  token: Optional[str] = ""
-  id: Optional[str] = "default"
-
-  @property
-  def uri(self) -> str:
-    return f"http://{self.host}:{self.port}";
-
-
-class CustomMilvusContainer(MilvusContainer):
-  def __init__(
-      self,
-      image: str,
-      service_container_port,
-      healthcheck_container_port,
-      **kwargs,
-  ) -> None:
-    # Skip the parent class's constructor and go straight to
-    # GenericContainer.
-    super(MilvusContainer, self).__init__(image=image, **kwargs)
-    self.port = service_container_port
-    self.healthcheck_port = healthcheck_container_port
-    self.with_exposed_ports(service_container_port, healthcheck_container_port)
-
-    # Get free host ports.
-    service_host_port = MilvusEnrichmentTestHelper.find_free_port()
-    healthcheck_host_port = MilvusEnrichmentTestHelper.find_free_port()
-
-    # Bind container and host ports.
-    self.with_bind_ports(service_container_port, service_host_port)
-    self.with_bind_ports(healthcheck_container_port, healthcheck_host_port)
-    self.cmd = "milvus run standalone"
-
-    # Set environment variables needed for Milvus.
-    envs = {
-        "ETCD_USE_EMBED": "true",
-        "ETCD_DATA_DIR": "/var/lib/milvus/etcd",
-        "COMMON_STORAGETYPE": "local",
-        "METRICS_PORT": str(healthcheck_container_port)
-    }
-    for env, value in envs.items():
-      self.with_env(env, value)
-
-
-class MilvusEnrichmentTestHelper:
-  # IMPORTANT: When upgrading the Milvus server version, ensure the pymilvus
-  # Python SDK client in setup.py is updated to match. Referring to the Milvus
-  # release notes compatibility matrix at
-  # https://milvus.io/docs/release_notes.md or PyPI at
-  # https://pypi.org/project/pymilvus/ for version compatibility.
-  # Example: Milvus v2.6.0 requires pymilvus==2.6.0 (exact match required).
-  @staticmethod
-  def start_db_container(
-      image="milvusdb/milvus:v2.5.10",
-      max_vec_fields=5,
-      vector_client_max_retries=3,
-      tc_max_retries=TC_MAX_TRIES) -> Optional[MilvusDBContainerInfo]:
-    service_container_port = MilvusEnrichmentTestHelper.find_free_port()
-    healthcheck_container_port = MilvusEnrichmentTestHelper.find_free_port()
-    user_yaml_creator = MilvusEnrichmentTestHelper.create_user_yaml
-    with user_yaml_creator(service_container_port, max_vec_fields) as cfg:
-      info = None
-      testcontainers_config.max_tries = tc_max_retries
-      for i in range(vector_client_max_retries):
-        try:
-          vector_db_container = CustomMilvusContainer(
-              image=image,
-              service_container_port=service_container_port,
-              healthcheck_container_port=healthcheck_container_port)
-          vector_db_container = vector_db_container.with_volume_mapping(
-              cfg, "/milvus/configs/user.yaml")
-          vector_db_container.start()
-          host = vector_db_container.get_container_host_ip()
-          port = vector_db_container.get_exposed_port(service_container_port)
-          info = MilvusDBContainerInfo(vector_db_container, host, port)
-          testcontainers_config.max_tries = TC_MAX_TRIES
-          _LOGGER.info(
-              "milvus db container started successfully on %s.", info.uri)
-          break
-        except Exception as e:
-          stdout_logs, stderr_logs = vector_db_container.get_logs()
-          stdout_logs = stdout_logs.decode("utf-8")
-          stderr_logs = stderr_logs.decode("utf-8")
-          _LOGGER.warning(
-              "Retry %d/%d: Failed to start Milvus DB container. Reason: %s. "
-              "STDOUT logs:\n%s\nSTDERR logs:\n%s",
-              i + 1,
-              vector_client_max_retries,
-              e,
-              stdout_logs,
-              stderr_logs)
-          if i == vector_client_max_retries - 1:
-            _LOGGER.error(
-                "Unable to start milvus db container for I/O tests after %d "
-                "retries. Tests cannot proceed. STDOUT logs:\n%s\n"
-                "STDERR logs:\n%s",
-                vector_client_max_retries,
-                stdout_logs,
-                stderr_logs)
-            raise e
-      return info
-
-  @staticmethod
-  def stop_db_container(db_info: MilvusDBContainerInfo):
-    if db_info is None:
-      _LOGGER.warning("Milvus db info is None. Skipping stop operation.")
-      return
-    try:
-      _LOGGER.debug("Stopping milvus db container.")
-      db_info.container.stop()
-      _LOGGER.info("milvus db container stopped successfully.")
-    except Exception as e:
-      _LOGGER.warning(
-          "Error encountered while stopping milvus db container: %s", e)
-
-  @staticmethod
-  def initialize_db_with_data(connc_params: MilvusConnectionParameters):
-    # Open the connection to the milvus db.
-    client = MilvusClient(**connc_params.__dict__)
-
-    # Configure schema.
-    field_schemas: List[FieldSchema] = cast(
-        List[FieldSchema], MILVUS_IT_CONFIG["fields"])
-    schema = CollectionSchema(
-        fields=field_schemas, functions=MILVUS_IT_CONFIG["functions"])
-
-    # Create collection with the schema.
-    collection_name = MILVUS_IT_CONFIG["collection_name"]
-    index_function: Callable[[], IndexParams] = cast(
-        Callable[[], IndexParams], MILVUS_IT_CONFIG["index"])
-    client.create_collection(
-        collection_name=collection_name,
-        schema=schema,
-        index_params=index_function())
-
-    # Assert that collection was created.
-    collection_error = f"Expected collection '{collection_name}' to be 
created."
-    assert client.has_collection(collection_name), collection_error
-
-    # Gather all fields we have excluding 'sparse_embedding_bm25' special field.
-    fields = list(map(lambda field: field.name, field_schemas))
-
-    # Prep data for indexing. Currently we can't insert sparse vectors for BM25
-    # sparse embedding field as it would be automatically generated by Milvus
-    # through the registered BM25 function.
-    data_ready_to_index = []
-    for doc in MILVUS_IT_CONFIG["corpus"]:
-      item = {}
-      for field in fields:
-        if field.startswith("dense_embedding"):
-          item[field] = doc["dense_embedding"]
-        elif field == "sparse_embedding_inner_product":
-          item[field] = doc["sparse_embedding"]
-        elif field == "sparse_embedding_bm25":
-          # It is automatically generated by Milvus from the content field.
-          continue
-        else:
-          item[field] = doc[field]
-      data_ready_to_index.append(item)
-
-    # Index data.
-    result = client.insert(
-        collection_name=collection_name, data=data_ready_to_index)
-
-    # Assert that the intended data has been properly indexed.
-    insertion_err = f'failed to insert the {result["insert_count"]} data points'
-    assert result["insert_count"] == len(data_ready_to_index), insertion_err
-
-    # Release the collection from memory. It will be loaded lazily when the
-    # enrichment handler is invoked.
-    client.release_collection(collection_name)
-
-    # Close the connection to the Milvus database, as no further preparation
-    # operations are needed  before executing the enrichment handler.
-    client.close()
-
-    return collection_name
-
-  @staticmethod
-  def find_free_port():
-    """Find a free port on the local machine."""
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-      # Bind to port 0, which asks OS to assign a free port.
-      s.bind(('', 0))
-      s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-      # Return the port number assigned by OS.
-      return s.getsockname()[1]
-
-  @staticmethod
-  @contextlib.contextmanager
-  def create_user_yaml(service_port: int, max_vector_field_num=5):
-    """Creates a temporary user.yaml file for Milvus configuration.
-
-      This user yaml file overrides Milvus default configurations. It sets
-      the Milvus service port to the specified container service port. The
-      default for maxVectorFieldNum is 4, but we need 5
-      (one unique field for each metric).
-
-      Args:
-        service_port: Port number for the Milvus service.
-        max_vector_field_num: Max number of vec fields allowed per collection.
-
-      Yields:
-          str: Path to the created temporary yaml file.
-      """
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml',
-                                     delete=False) as temp_file:
-      # Define the content for user.yaml.
-      user_config = {
-          'proxy': {
-              'maxVectorFieldNum': max_vector_field_num, 'port': service_port
-          },
-          'etcd': {
-              'use': {
-                  'embed': True
-              }, 'data': {
-                  'dir': '/var/lib/milvus/etcd'
-              }
-          }
-      }
-
-      # Write the content to the file.
-      yaml.dump(user_config, temp_file, default_flow_style=False)
-      path = temp_file.name
-
-    try:
-      yield path
-    finally:
-      if os.path.exists(path):
-        os.remove(path)
-
-
 @pytest.mark.require_docker_in_docker
 @unittest.skipUnless(
     platform.system() == "Linux",
@@ -492,25 +236,24 @@ class MilvusEnrichmentTestHelper:
 class TestMilvusSearchEnrichment(unittest.TestCase):
   """Tests for search functionality across all search strategies"""
 
-  _db: MilvusDBContainerInfo
+  _db: VectorDBContainerInfo
 
   @classmethod
   def setUpClass(cls):
-    cls._db = MilvusEnrichmentTestHelper.start_db_container()
+    cls._db = MilvusTestHelpers.start_db_container()
     cls._connection_params = MilvusConnectionParameters(
         uri=cls._db.uri,
         user=cls._db.user,
         password=cls._db.password,
-        db_id=cls._db.id,
-        token=cls._db.token,
-        timeout=60.0)  # Increase timeout to 60s for container startup
+        db_name=cls._db.id,
+        token=cls._db.token)
     cls._collection_load_params = MilvusCollectionLoadParameters()
-    cls._collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data(
-        cls._connection_params)
+    cls._collection_name = MilvusTestHelpers.initialize_db_with_data(
+        cls._connection_params, MILVUS_IT_CONFIG)
 
   @classmethod
   def tearDownClass(cls):
-    MilvusEnrichmentTestHelper.stop_db_container(cls._db)
+    MilvusTestHelpers.stop_db_container(cls._db)
     cls._db = None
 
   def test_invalid_query_on_non_existent_collection(self):
@@ -589,8 +332,8 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
     with TestPipeline() as p:
       result = (p | beam.Create(test_chunks) | Enrichment(handler))
       assert_that(
-          result,
-          lambda actual: assert_chunks_equivalent(actual, expected_chunks))
+          result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent(
+              actual, expected_chunks))
 
   def test_filtered_search_with_cosine_similarity_and_batching(self):
     test_chunks = [
@@ -717,8 +460,8 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
     with TestPipeline() as p:
       result = (p | beam.Create(test_chunks) | Enrichment(handler))
       assert_that(
-          result,
-          lambda actual: assert_chunks_equivalent(actual, expected_chunks))
+          result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent(
+              actual, expected_chunks))
 
   def test_filtered_search_with_bm25_full_text_and_batching(self):
     test_chunks = [
@@ -822,8 +565,8 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
     with TestPipeline() as p:
       result = (p | beam.Create(test_chunks) | Enrichment(handler))
       assert_that(
-          result,
-          lambda actual: assert_chunks_equivalent(actual, expected_chunks))
+          result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent(
+              actual, expected_chunks))
 
   def test_vector_search_with_euclidean_distance(self):
     test_chunks = [
@@ -963,8 +706,8 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
     with TestPipeline() as p:
       result = (p | beam.Create(test_chunks) | Enrichment(handler))
       assert_that(
-          result,
-          lambda actual: assert_chunks_equivalent(actual, expected_chunks))
+          result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent(
+              actual, expected_chunks))
 
   def test_vector_search_with_inner_product_similarity(self):
     test_chunks = [
@@ -1103,8 +846,8 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
     with TestPipeline() as p:
       result = (p | beam.Create(test_chunks) | Enrichment(handler))
       assert_that(
-          result,
-          lambda actual: assert_chunks_equivalent(actual, expected_chunks))
+          result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent(
+              actual, expected_chunks))
 
   def test_keyword_search_with_inner_product_sparse_embedding(self):
     test_chunks = [
@@ -1168,8 +911,8 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
     with TestPipeline() as p:
       result = (p | beam.Create(test_chunks) | Enrichment(handler))
       assert_that(
-          result,
-          lambda actual: assert_chunks_equivalent(actual, expected_chunks))
+          result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent(
+              actual, expected_chunks))
 
   def test_hybrid_search(self):
     test_chunks = [
@@ -1241,134 +984,8 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
     with TestPipeline() as p:
       result = (p | beam.Create(test_chunks) | Enrichment(handler))
       assert_that(
-          result,
-          lambda actual: assert_chunks_equivalent(actual, expected_chunks))
-
-
-def parse_chunk_strings(chunk_str_list: List[str]) -> List[Chunk]:
-  parsed_chunks = []
-
-  # Define safe globals and disable built-in functions for safety.
-  safe_globals = {
-      'Chunk': Chunk,
-      'Content': Content,
-      'Embedding': Embedding,
-      'defaultdict': defaultdict,
-      'list': list,
-      '__builtins__': {}
-  }
-
-  for raw_str in chunk_str_list:
-    try:
-      # replace "<class 'list'>" with actual list reference.
-      cleaned_str = re.sub(
-          r"defaultdict\(<class 'list'>", "defaultdict(list", raw_str)
-
-      # Evaluate string in restricted environment.
-      chunk = eval(cleaned_str, safe_globals)  # pylint: disable=eval-used
-      if isinstance(chunk, Chunk):
-        parsed_chunks.append(chunk)
-      else:
-        raise ValueError("Parsed object is not a Chunk instance")
-    except Exception as e:
-      raise ValueError(f"Error parsing string:\n{raw_str}\n{e}")
-
-  return parsed_chunks
-
-
-def assert_chunks_equivalent(
-    actual_chunks: List[Chunk], expected_chunks: List[Chunk]):
-  """assert_chunks_equivalent checks for presence rather than exact match"""
-  # Sort both lists by ID to ensure consistent ordering.
-  actual_sorted = sorted(actual_chunks, key=lambda c: c.id)
-  expected_sorted = sorted(expected_chunks, key=lambda c: c.id)
-
-  actual_len = len(actual_sorted)
-  expected_len = len(expected_sorted)
-  err_msg = (
-      f"Different number of chunks, actual: {actual_len}, "
-      f"expected: {expected_len}")
-  assert actual_len == expected_len, err_msg
-
-  for actual, expected in zip(actual_sorted, expected_sorted):
-    # Assert that IDs match.
-    assert actual.id == expected.id
-
-    # Assert that dense embeddings match.
-    err_msg = f"Dense embedding mismatch for chunk {actual.id}"
-    assert actual.dense_embedding == expected.dense_embedding, err_msg
-
-    # Assert that sparse embeddings match.
-    err_msg = f"Sparse embedding mismatch for chunk {actual.id}"
-    assert actual.sparse_embedding == expected.sparse_embedding, err_msg
-
-    # Assert that text content match.
-    err_msg = f"Text Content mismatch for chunk {actual.id}"
-    assert actual.content.text == expected.content.text, err_msg
-
-    # For enrichment_data, be more flexible.
-    # If "expected" has values for enrichment_data but actual doesn't, that's
-    # acceptable since vector search results can vary based on many factors
-    # including implementation details, vector database state, and slight
-    # variations in similarity calculations.
-
-    # First ensure the enrichment data key exists.
-    err_msg = f"Missing enrichment_data key in chunk {actual.id}"
-    assert 'enrichment_data' in actual.metadata, err_msg
-
-    # For enrichment_data, ensure consistent ordering of results.
-    actual_data = actual.metadata['enrichment_data']
-    expected_data = expected.metadata['enrichment_data']
-
-    # If actual has enrichment data, then perform detailed validation.
-    if actual_data and actual_data.get('id'):
-      # Validate IDs have consistent ordering.
-      actual_ids = sorted(actual_data['id'])
-      expected_ids = sorted(expected_data['id'])
-      err_msg = f"IDs in enrichment_data don't match for chunk {actual.id}"
-      assert actual_ids == expected_ids, err_msg
-
-      # Ensure the distance key exist.
-      err_msg = f"Missing distance key in metadata {actual.id}"
-      assert 'distance' in actual_data, err_msg
-
-      # Validate distances exist and have same length as IDs.
-      actual_distances = actual_data['distance']
-      expected_distances = expected_data['distance']
-      err_msg = (
-          "Number of distances doesn't match number of IDs for "
-          f"chunk {actual.id}")
-      assert len(actual_distances) == len(expected_distances), err_msg
-
-      # Ensure the fields key exist.
-      err_msg = f"Missing fields key in metadata {actual.id}"
-      assert 'fields' in actual_data, err_msg
-
-      # Validate fields have consistent content.
-      # Sort fields by 'id' to ensure consistent ordering.
-      actual_fields_sorted = sorted(
-          actual_data['fields'], key=lambda f: f.get('id', 0))
-      expected_fields_sorted = sorted(
-          expected_data['fields'], key=lambda f: f.get('id', 0))
-
-      # Compare field IDs.
-      actual_field_ids = [f.get('id') for f in actual_fields_sorted]
-      expected_field_ids = [f.get('id') for f in expected_fields_sorted]
-      err_msg = f"Field IDs don't match for chunk {actual.id}"
-      assert actual_field_ids == expected_field_ids, err_msg
-
-      # Compare field content.
-      for a_f, e_f in zip(actual_fields_sorted, expected_fields_sorted):
-        # Ensure the id key exist.
-        err_msg = f"Missing id key in metadata.fields {actual.id}"
-        assert 'id' in a_f
-
-        err_msg = f"Field ID mismatch chunk {actual.id}"
-        assert a_f['id'] == e_f['id'], err_msg
-
-        # Validate field metadata.
-        err_msg = f"Field Metadata doesn't match for chunk {actual.id}"
-        assert a_f['metadata'] == e_f['metadata'], err_msg
+          result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent(
+              actual, expected_chunks))
 
 
 if __name__ == '__main__':
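The container lifecycle these tests now delegate to the shared helpers, as
exercised by setUpClass/tearDownClass above (a sketch; MILVUS_IT_CONFIG is the
schema/corpus dict defined earlier in this file, and the exact connection
fields passed are illustrative):

    from apache_beam.ml.rag.test_utils import MilvusTestHelpers
    from apache_beam.ml.rag.utils import MilvusConnectionParameters

    db = MilvusTestHelpers.start_db_container()  # standalone Milvus container
    try:
      params = MilvusConnectionParameters(
          uri=db.uri, user=db.user, password=db.password, token=db.token)
      # Creates the collection, indexes the corpus, then releases the
      # collection so the enrichment handler can load it lazily.
      collection = MilvusTestHelpers.initialize_db_with_data(
          params, MILVUS_IT_CONFIG)
      # ... run enrichment pipelines against `collection` ...
    finally:
      MilvusTestHelpers.stop_db_container(db)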
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py
index eca740a4e9c..68afa56e399 100644
--- a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py
+++ b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py
@@ -30,16 +30,16 @@ from apache_beam.ml.rag.types import Chunk
 
 def chunk_embedding_fn(chunk: Chunk) -> str:
   """Convert embedding to PostgreSQL array string.
-    
+
     Formats dense embedding as a PostgreSQL-compatible array string.
     Example: [1.0, 2.0] -> '{1.0,2.0}'
-    
+
     Args:
         chunk: Input Chunk object.
-    
+
     Returns:
         str: PostgreSQL array string representation of the embedding.
-    
+
     Raises:
         ValueError: If chunk has no dense embedding.
     """
@@ -51,7 +51,7 @@ def chunk_embedding_fn(chunk: Chunk) -> str:
 @dataclass
 class ColumnSpec:
   """Specification for mapping Chunk fields to SQL columns for insertion.
-    
+
     Defines how to extract and format values from Chunks into database columns,
     handling the full pipeline from Python value to SQL insertion.
 
@@ -71,7 +71,7 @@ class ColumnSpec:
             Common examples:
             - "::float[]" for vector arrays
             - "::jsonb" for JSON data
-    
+
     Examples:
         Basic text column (uses standard JDBC type mapping):
         >>> ColumnSpec.text(
@@ -83,7 +83,7 @@ class ColumnSpec:
         Vector column with explicit array casting:
         >>> ColumnSpec.vector(
         ...     column_name="embedding",
-        ...     value_fn=lambda chunk: '{' + 
+        ...     value_fn=lambda chunk: '{' +
         ...         ','.join(map(str, chunk.embedding.dense_embedding)) + '}'
         ... )
         # Results in: INSERT INTO table (embedding) VALUES (?::float[])
@@ -168,17 +168,17 @@ class ColumnSpecsBuilder:
       convert_fn: Optional[Callable[[str], Any]] = None,
       sql_typecast: Optional[str] = None) -> 'ColumnSpecsBuilder':
     """Add ID :class:`.ColumnSpec` with optional type and conversion.
-        
+
         Args:
             column_name: Name for the ID column (defaults to "id")
             python_type: Python type for the column (defaults to str)
             convert_fn: Optional function to convert the chunk ID
                        If None, uses ID as-is
             sql_typecast: Optional SQL type cast
-        
+
         Returns:
             Self for method chaining
-        
+
         Example:
             >>> builder.with_id_spec(
             ...     column_name="doc_id",
@@ -205,17 +205,17 @@ class ColumnSpecsBuilder:
       convert_fn: Optional[Callable[[str], Any]] = None,
       sql_typecast: Optional[str] = None) -> 'ColumnSpecsBuilder':
     """Add content :class:`.ColumnSpec` with optional type and conversion.
-      
+
       Args:
           column_name: Name for the content column (defaults to "content")
           python_type: Python type for the column (defaults to str)
           convert_fn: Optional function to convert the content text
                       If None, uses content text as-is
           sql_typecast: Optional SQL type cast
-      
+
       Returns:
           Self for method chaining
-      
+
       Example:
           >>> builder.with_content_spec(
           ...     column_name="content_length",
@@ -244,17 +244,17 @@ class ColumnSpecsBuilder:
       convert_fn: Optional[Callable[[Dict[str, Any]], Any]] = None,
       sql_typecast: Optional[str] = "::jsonb") -> 'ColumnSpecsBuilder':
     """Add metadata :class:`.ColumnSpec` with optional type and conversion.
-      
+
       Args:
           column_name: Name for the metadata column (defaults to "metadata")
           python_type: Python type for the column (defaults to str)
           convert_fn: Optional function to convert the metadata dictionary
                       If None and python_type is str, converts to JSON string
           sql_typecast: Optional SQL type cast (defaults to "::jsonb")
-      
+
       Returns:
           Self for method chaining
-      
+
       Example:
           >>> builder.with_metadata_spec(
           ...     column_name="meta_tags",
@@ -283,19 +283,19 @@ class ColumnSpecsBuilder:
       convert_fn: Optional[Callable[[List[float]], Any]] = None
   ) -> 'ColumnSpecsBuilder':
     """Add embedding :class:`.ColumnSpec` with optional conversion.
-      
+
       Args:
           column_name: Name for the embedding column (defaults to "embedding")
           convert_fn: Optional function to convert the dense embedding values
                       If None, uses default PostgreSQL array format
-      
+
       Returns:
           Self for method chaining
-      
+
       Example:
           >>> builder.with_embedding_spec(
           ...     column_name="embedding_vector",
-          ...     convert_fn=lambda values: '{' + ','.join(f"{x:.4f}" 
+          ...     convert_fn=lambda values: '{' + ','.join(f"{x:.4f}"
           ...       for x in values) + '}'
           ... )
       """
@@ -330,7 +330,7 @@ class ColumnSpecsBuilder:
                       desired type. If None, value is used as-is
             default: Default value if field is missing from metadata
             sql_typecast: Optional SQL type cast (e.g. "::timestamp")
-        
+
         Returns:
             Self for chaining
 
@@ -385,17 +385,17 @@ class ColumnSpecsBuilder:
 
   def add_custom_column_spec(self, spec: ColumnSpec) -> 'ColumnSpecsBuilder':
     """Add a custom :class:`.ColumnSpec` to the builder.
-    
+
    Use this method when you need complete control over the :class:`.ColumnSpec`
     , including custom value extraction and type handling.
-    
+
     Args:
         spec: A :class:`.ColumnSpec` instance defining the column name, type,
             value extraction, and optional SQL type casting.
-    
+
     Returns:
         Self for method chaining
-    
+
     Examples:
         Custom text column from chunk metadata:
 
@@ -430,12 +430,12 @@ class ConflictResolution:
             IGNORE: Skips conflicting records.
         update_fields: Optional list of fields to update on conflict. If None,
             all non-conflict fields are updated.
-        
+
     Examples:
         Simple primary key:
 
         >>> ConflictResolution("id")
-        
+
         Composite key with specific update fields:
 
         >>> ConflictResolution(
@@ -443,7 +443,7 @@ class ConflictResolution:
         ...     action="UPDATE",
         ...     update_fields=["embedding", "content"]
         ... )
-        
+
         Ignore conflicts:
 
         >>> ConflictResolution(
diff --git a/sdks/python/apache_beam/ml/rag/test_utils.py b/sdks/python/apache_beam/ml/rag/test_utils.py
new file mode 100644
index 00000000000..f4acb105892
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/test_utils.py
@@ -0,0 +1,414 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import contextlib
+import logging
+import os
+import socket
+import tempfile
+import unittest
+from dataclasses import dataclass
+from typing import Callable
+from typing import List
+from typing import Optional
+from typing import cast
+
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.utils import retry_with_backoff
+
+# pylint: disable=ungrouped-imports
+try:
+  import yaml
+  from pymilvus import CollectionSchema
+  from pymilvus import FieldSchema
+  from pymilvus import MilvusClient
+  from pymilvus.exceptions import MilvusException
+  from pymilvus.milvus_client import IndexParams
+  from testcontainers.core.config import testcontainers_config
+  from testcontainers.core.generic import DbContainer
+  from testcontainers.milvus import MilvusContainer
+
+  from apache_beam.ml.rag.enrichment.milvus_search import MilvusConnectionParameters
+except ImportError as e:
+  raise unittest.SkipTest(f'RAG test util dependencies not installed: {str(e)}')
+
+_LOGGER = logging.getLogger(__name__)
+
+
+@dataclass
+class VectorDBContainerInfo:
+  """Container information for vector database test instances.
+
+  Holds connection details and container reference for testing with
+  vector databases like Milvus in containerized environments.
+  """
+  container: DbContainer
+  host: str
+  port: int
+  user: str = ""
+  password: str = ""
+  token: str = ""
+  id: str = "default"
+
+  @property
+  def uri(self) -> str:
+    return f"http://{self.host}:{self.port}";
+
+
+class TestHelpers:
+  @staticmethod
+  def find_free_port():
+    """Find a free port on the local machine."""
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+      # Bind to port 0, which asks OS to assign a free port.
+      s.bind(('', 0))
+      s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+      # Return the port number assigned by OS.
+      return s.getsockname()[1]
+
+
+class CustomMilvusContainer(MilvusContainer):
+  """Custom Milvus container with configurable ports and environment setup.
+
+  Extends MilvusContainer to provide custom port binding and environment
+  configuration for testing with standalone Milvus instances.
+  """
+  def __init__(
+      self,
+      image: str,
+      service_container_port,
+      healthcheck_container_port,
+      **kwargs,
+  ) -> None:
+    # Skip the parent class's constructor and go straight to
+    # GenericContainer.
+    super(MilvusContainer, self).__init__(image=image, **kwargs)
+    self.port = service_container_port
+    self.healthcheck_port = healthcheck_container_port
+    self.with_exposed_ports(service_container_port, healthcheck_container_port)
+
+    # Get free host ports.
+    service_host_port = TestHelpers.find_free_port()
+    healthcheck_host_port = TestHelpers.find_free_port()
+
+    # Bind container and host ports.
+    self.with_bind_ports(service_container_port, service_host_port)
+    self.with_bind_ports(healthcheck_container_port, healthcheck_host_port)
+    self.cmd = "milvus run standalone"
+
+    # Set environment variables needed for Milvus.
+    envs = {
+        "ETCD_USE_EMBED": "true",
+        "ETCD_DATA_DIR": "/var/lib/milvus/etcd",
+        "COMMON_STORAGETYPE": "local",
+        "METRICS_PORT": str(healthcheck_container_port)
+    }
+    for env, value in envs.items():
+      self.with_env(env, value)
+
+
+class MilvusTestHelpers:
+  """Helper utilities for testing Milvus vector database operations.
+
+  Provides static methods for managing test containers, configuration files,
+  and chunk comparison utilities for Milvus-based integration tests.
+  """
+  # IMPORTANT: When upgrading the Milvus server version, ensure the pymilvus
+  # Python SDK client in setup.py is updated to match. Refer to the Milvus
+  # release notes compatibility matrix at
+  # https://milvus.io/docs/release_notes.md or PyPI at
+  # https://pypi.org/project/pymilvus/ for version compatibility.
+  # Example: Milvus v2.6.0 requires pymilvus==2.6.0 (exact match required).
+  @staticmethod
+  def start_db_container(
+      image="milvusdb/milvus:v2.5.10",
+      max_vec_fields=5,
+      vector_client_max_retries=3,
+      tc_max_retries=None) -> Optional[VectorDBContainerInfo]:
+    service_container_port = TestHelpers.find_free_port()
+    healthcheck_container_port = TestHelpers.find_free_port()
+    user_yaml_creator = MilvusTestHelpers.create_user_yaml
+    with user_yaml_creator(service_container_port, max_vec_fields) as cfg:
+      info = None
+      original_tc_max_tries = testcontainers_config.max_tries
+      if tc_max_retries is not None:
+        testcontainers_config.max_tries = tc_max_retries
+      for i in range(vector_client_max_retries):
+        try:
+          vector_db_container = CustomMilvusContainer(
+              image=image,
+              service_container_port=service_container_port,
+              healthcheck_container_port=healthcheck_container_port)
+          vector_db_container = vector_db_container.with_volume_mapping(
+              cfg, "/milvus/configs/user.yaml")
+          vector_db_container.start()
+          host = vector_db_container.get_container_host_ip()
+          port = vector_db_container.get_exposed_port(service_container_port)
+          info = VectorDBContainerInfo(vector_db_container, host, port)
+          _LOGGER.info(
+              "milvus db container started successfully on %s.", info.uri)
+          break
+        except Exception as e:
+          stdout_logs, stderr_logs = vector_db_container.get_logs()
+          stdout_logs = stdout_logs.decode("utf-8")
+          stderr_logs = stderr_logs.decode("utf-8")
+          _LOGGER.warning(
+              "Retry %d/%d: Failed to start Milvus DB container. Reason: %s. "
+              "STDOUT logs:\n%s\nSTDERR logs:\n%s",
+              i + 1,
+              vector_client_max_retries,
+              e,
+              stdout_logs,
+              stderr_logs)
+          if i == vector_client_max_retries - 1:
+            _LOGGER.error(
+                "Unable to start milvus db container for I/O tests after %d "
+                "retries. Tests cannot proceed. STDOUT logs:\n%s\n"
+                "STDERR logs:\n%s",
+                vector_client_max_retries,
+                stdout_logs,
+                stderr_logs)
+            raise e
+        finally:
+          testcontainers_config.max_tries = original_tc_max_tries
+      return info
+
+  @staticmethod
+  def stop_db_container(db_info: VectorDBContainerInfo):
+    if db_info is None:
+      _LOGGER.warning("Milvus db info is None. Skipping stop operation.")
+      return
+    _LOGGER.debug("Stopping milvus db container.")
+    db_info.container.stop()
+    _LOGGER.info("milvus db container stopped successfully.")
+
+  @staticmethod
+  def initialize_db_with_data(
+      connc_params: MilvusConnectionParameters, config: dict):
+    # Open the connection to the milvus db with retry.
+    def create_client():
+      return MilvusClient(**connc_params.__dict__)
+
+    client = retry_with_backoff(
+        create_client,
+        max_retries=3,
+        retry_delay=1.0,
+        operation_name="Test Milvus client connection",
+        exception_types=(MilvusException, ))
+
+    # Configure schema.
+    field_schemas: List[FieldSchema] = cast(List[FieldSchema], config["fields"])
+    schema = CollectionSchema(
+        fields=field_schemas, functions=config["functions"])
+
+    # Create collection with the schema.
+    collection_name = config["collection_name"]
+    index_function: Callable[[], IndexParams] = cast(
+        Callable[[], IndexParams], config["index"])
+    client.create_collection(
+        collection_name=collection_name,
+        schema=schema,
+        index_params=index_function())
+
+    # Assert that collection was created.
+    collection_error = f"Expected collection '{collection_name}' to be 
created."
+    assert client.has_collection(collection_name), collection_error
+
+    # Gather all fields we have excluding 'sparse_embedding_bm25' special field.
+    fields = list(map(lambda field: field.name, field_schemas))
+
+    # Prep data for indexing. Currently we can't insert sparse vectors for BM25
+    # sparse embedding field as it would be automatically generated by Milvus
+    # through the registered BM25 function.
+    data_ready_to_index = []
+    for doc in config["corpus"]:
+      item = {}
+      for field in fields:
+        if field.startswith("dense_embedding"):
+          item[field] = doc["dense_embedding"]
+        elif field == "sparse_embedding_inner_product":
+          item[field] = doc["sparse_embedding"]
+        elif field == "sparse_embedding_bm25":
+          # It is automatically generated by Milvus from the content field.
+          continue
+        else:
+          item[field] = doc[field]
+      data_ready_to_index.append(item)
+
+    # Index data.
+    result = client.insert(
+        collection_name=collection_name, data=data_ready_to_index)
+
+    # Assert that the intended data has been properly indexed.
+    insertion_err = f'failed to insert the {result["insert_count"]} data points'
+    assert result["insert_count"] == len(data_ready_to_index), insertion_err
+
+    # Release the collection from memory. It will be loaded lazily when the
+    # enrichment handler is invoked.
+    client.release_collection(collection_name)
+
+    # Close the connection to the Milvus database, as no further preparation
+    # operations are needed before executing the enrichment handler.
+    client.close()
+
+    return collection_name
+
+  @staticmethod
+  @contextlib.contextmanager
+  def create_user_yaml(service_port: int, max_vector_field_num=5):
+    """Creates a temporary user.yaml file for Milvus configuration.
+
+      This user yaml file overrides Milvus default configurations. It sets
+      the Milvus service port to the specified container service port. The
+      default for maxVectorFieldNum is 4, but we need 5
+      (one unique field for each metric).
+
+      Args:
+        service_port: Port number for the Milvus service.
+        max_vector_field_num: Max number of vec fields allowed per collection.
+
+      Yields:
+          str: Path to the created temporary yaml file.
+      """
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml',
+                                     delete=False) as temp_file:
+      # Define the content for user.yaml.
+      user_config = {
+          'proxy': {
+              'maxVectorFieldNum': max_vector_field_num, 'port': service_port
+          },
+          'etcd': {
+              'use': {
+                  'embed': True
+              }, 'data': {
+                  'dir': '/var/lib/milvus/etcd'
+              }
+          }
+      }
+
+      # Write the content to the file.
+      yaml.dump(user_config, temp_file, default_flow_style=False)
+      path = temp_file.name
+
+    try:
+      yield path
+    finally:
+      if os.path.exists(path):
+        os.remove(path)
+
+  @staticmethod
+  def assert_chunks_equivalent(
+      actual_chunks: List[Chunk], expected_chunks: List[Chunk]):
+    """assert_chunks_equivalent checks for presence rather than exact match"""
+    # Sort both lists by ID to ensure consistent ordering.
+    actual_sorted = sorted(actual_chunks, key=lambda c: c.id)
+    expected_sorted = sorted(expected_chunks, key=lambda c: c.id)
+
+    actual_len = len(actual_sorted)
+    expected_len = len(expected_sorted)
+    err_msg = (
+        f"Different number of chunks, actual: {actual_len}, "
+        f"expected: {expected_len}")
+    assert actual_len == expected_len, err_msg
+
+    for actual, expected in zip(actual_sorted, expected_sorted):
+      # Assert that IDs match.
+      assert actual.id == expected.id
+
+      # Assert that dense embeddings match.
+      err_msg = f"Dense embedding mismatch for chunk {actual.id}"
+      assert actual.dense_embedding == expected.dense_embedding, err_msg
+
+      # Assert that sparse embeddings match.
+      err_msg = f"Sparse embedding mismatch for chunk {actual.id}"
+      assert actual.sparse_embedding == expected.sparse_embedding, err_msg
+
+      # Assert that text content matches.
+      err_msg = f"Text content mismatch for chunk {actual.id}"
+      assert actual.content.text == expected.content.text, err_msg
+
+      # For enrichment_data, be more flexible.
+      # If "expected" has values for enrichment_data but actual doesn't, that's
+      # acceptable since vector search results can vary based on many factors
+      # including implementation details, vector database state, and slight
+      # variations in similarity calculations.
+
+      # First ensure the enrichment data key exists.
+      err_msg = f"Missing enrichment_data key in chunk {actual.id}"
+      assert 'enrichment_data' in actual.metadata, err_msg
+
+      # Extract the enrichment data from both chunks.
+      actual_data = actual.metadata['enrichment_data']
+      expected_data = expected.metadata['enrichment_data']
+
+      # If actual has enrichment data, then perform detailed validation.
+      if actual_data:
+        # Ensure the id key exists.
+        err_msg = f"Missing id key in metadata {actual.id}"
+        assert 'id' in actual_data, err_msg
+
+        # Validate that both sides contain the same IDs, order-insensitively.
+        actual_ids = sorted(actual_data['id'])
+        expected_ids = sorted(expected_data['id'])
+        err_msg = f"IDs in enrichment_data don't match for chunk {actual.id}"
+        assert actual_ids == expected_ids, err_msg
+
+        # Ensure the distance key exists.
+        err_msg = f"Missing distance key in metadata {actual.id}"
+        assert 'distance' in actual_data, err_msg
+
+        # Validate that both sides have the same number of distances.
+        actual_distances = actual_data['distance']
+        expected_distances = expected_data['distance']
+        err_msg = (
+            "Number of distances doesn't match for "
+            f"chunk {actual.id}")
+        assert len(actual_distances) == len(expected_distances), err_msg
+
+        # Ensure the fields key exists.
+        err_msg = f"Missing fields key in metadata {actual.id}"
+        assert 'fields' in actual_data, err_msg
+
+        # Validate fields have consistent content.
+        # Sort fields by 'id' to ensure consistent ordering.
+        actual_fields_sorted = sorted(
+            actual_data['fields'], key=lambda f: f.get('id', 0))
+        expected_fields_sorted = sorted(
+            expected_data['fields'], key=lambda f: f.get('id', 0))
+
+        # Compare field IDs.
+        actual_field_ids = [f.get('id') for f in actual_fields_sorted]
+        expected_field_ids = [f.get('id') for f in expected_fields_sorted]
+        err_msg = f"Field IDs don't match for chunk {actual.id}"
+        assert actual_field_ids == expected_field_ids, err_msg
+
+        # Compare field content.
+        for a_f, e_f in zip(actual_fields_sorted, expected_fields_sorted):
+          # Ensure the id key exists.
+          err_msg = f"Missing id key in metadata.fields {actual.id}"
+          assert 'id' in a_f, err_msg
+
+          err_msg = f"Field ID mismatch chunk {actual.id}"
+          assert a_f['id'] == e_f['id'], err_msg
+
+          # Validate field metadata.
+          err_msg = f"Field Metadata doesn't match for chunk {actual.id}"
+          assert a_f['metadata'] == e_f['metadata'], err_msg
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/ml/rag/utils.py b/sdks/python/apache_beam/ml/rag/utils.py
new file mode 100644
index 00000000000..d45e99be0ec
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/utils.py
@@ -0,0 +1,224 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import re
+import time
+import uuid
+from collections import defaultdict
+from dataclasses import dataclass
+from dataclasses import field
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Type
+
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.types import Content
+from apache_beam.ml.rag.types import Embedding
+
+_LOGGER = logging.getLogger(__name__)
+
+# Default batch size for writing data to Milvus, matching
+# JdbcIO.DEFAULT_BATCH_SIZE.
+DEFAULT_WRITE_BATCH_SIZE = 1000
+
+
+@dataclass
+class MilvusConnectionParameters:
+  """Configurations for establishing connections to Milvus servers.
+
+  Args:
+    uri: URI endpoint for connecting to Milvus server in the format
+      "http(s)://hostname:port".
+    user: Username for authentication. Required if authentication is enabled
+      and not using token authentication.
+    password: Password for authentication. Required if authentication is
+      enabled and not using token authentication.
+    db_name: Database name to connect to. Specifies which Milvus database to
+      use. Defaults to 'default'.
+    token: Authentication token as an alternative to username/password.
+    timeout: Connection timeout in seconds. Uses client default if None.
+    kwargs: Optional keyword arguments for additional connection parameters.
+      Enables forward compatibility.
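+
+  Example:
+    A minimal sketch; the endpoint below is hypothetical:
+
+    >>> params = MilvusConnectionParameters(uri="http://localhost:19530")
+    >>> "alias" in params.kwargs  # generated automatically in __post_init__
+    True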
+  """
+  uri: str
+  user: str = field(default_factory=str)
+  password: str = field(default_factory=str)
+  db_name: str = "default"
+  token: str = field(default_factory=str)
+  timeout: Optional[float] = None
+  kwargs: Dict[str, Any] = field(default_factory=dict)
+
+  def __post_init__(self):
+    if not self.uri:
+      raise ValueError("URI must be provided for Milvus connection")
+
+    # Generate unique alias if not provided. One-to-one mapping between alias
+    # and connection - each alias represents exactly one Milvus connection.
+    if "alias" not in self.kwargs:
+      alias = f"milvus_conn_{uuid.uuid4().hex[:8]}"
+      self.kwargs["alias"] = alias
+
+
+class MilvusHelpers:
+  """Utility class providing helper methods for Milvus vector db operations."""
+  @staticmethod
+  def sparse_embedding(
+      sparse_vector: Tuple[List[int],
+                           List[float]]) -> Optional[Dict[int, float]]:
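+    """Converts a (indices, values) sparse vector tuple into the
+    Milvus-compatible dict format; returns None for an empty input.
+
+    Example:
+      >>> MilvusHelpers.sparse_embedding(([1, 5], [0.5, 0.25]))
+      {1: 0.5, 5: 0.25}
+    """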
+    if not sparse_vector:
+      return None
+    # Converts sparse embedding from (indices, values) tuple format to
+    # Milvus-compatible values dict format {dimension_index: value, ...}.
+    indices, values = sparse_vector
+    return {int(idx): float(val) for idx, val in zip(indices, values)}
+
+
+def parse_chunk_strings(chunk_str_list: List[str]) -> List[Chunk]:
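+  """Parses repr-style strings back into Chunk objects.
+
+  Each string is evaluated in a restricted namespace that exposes only the
+  Chunk-related types, so the evaluated expression cannot reach builtins.
+
+  Example (a minimal sketch; the Chunk fields shown are illustrative):
+
+    >>> chunks = parse_chunk_strings(
+    ...     ["Chunk(id='c1', content=Content(text='hello'))"])
+  """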
+  parsed_chunks = []
+
+  # Define safe globals and disable built-in functions for safety.
+  safe_globals = {
+      'Chunk': Chunk,
+      'Content': Content,
+      'Embedding': Embedding,
+      'defaultdict': defaultdict,
+      'list': list,
+      '__builtins__': {}
+  }
+
+  for raw_str in chunk_str_list:
+    try:
+      # Replace the "<class 'list'>" repr with an actual list reference.
+      cleaned_str = re.sub(
+          r"defaultdict\(<class 'list'>", "defaultdict(list", raw_str)
+
+      # Evaluate string in restricted environment.
+      chunk = eval(cleaned_str, safe_globals)  # pylint: disable=eval-used
+      if isinstance(chunk, Chunk):
+        parsed_chunks.append(chunk)
+      else:
+        raise ValueError("Parsed object is not a Chunk instance")
+    except Exception as e:
+      raise ValueError(f"Error parsing string:\n{raw_str}\n{e}") from e
+
+  return parsed_chunks
+
+
+def unpack_dataclass_with_kwargs(dataclass_instance):
+  """Unpacks dataclass fields into a flat dict, merging kwargs with precedence.
+
+  Args:
+    dataclass_instance: Dataclass instance to unpack.
+
+  Returns:
+    dict: Flattened dictionary with kwargs taking precedence over fields.
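+
+  Example:
+    A minimal sketch using MilvusConnectionParameters from this module; the
+    endpoint is hypothetical:
+
+    >>> params = MilvusConnectionParameters(
+    ...     uri="http://localhost:19530", kwargs={"timeout": 30.0})
+    >>> unpack_dataclass_with_kwargs(params)["timeout"]
+    30.0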
+  """
+  # Create a copy of the dataclass's __dict__.
+  params_dict: dict = dataclass_instance.__dict__.copy()
+
+  # Extract the nested kwargs dictionary.
+  nested_kwargs = params_dict.pop('kwargs', {})
+
+  # Merge the dictionaries, with nested_kwargs taking precedence
+  # in case of duplicate keys.
+  return {**params_dict, **nested_kwargs}
+
+
+def retry_with_backoff(
+    operation: Callable[[], Any],
+    max_retries: int = 3,
+    retry_delay: float = 1.0,
+    retry_backoff_factor: float = 2.0,
+    operation_name: str = "operation",
+    exception_types: Tuple[Type[BaseException], ...] = (Exception, )
+) -> Any:
+  """Executes an operation with retry logic and exponential backoff.
+
+  This is a generic retry utility that can be used for any operation that may
+  fail transiently. It retries the operation with exponential backoff between
+  attempts.
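+
+  With the defaults (max_retries=3, retry_delay=1.0,
+  retry_backoff_factor=2.0), a failing operation is attempted up to 4 times,
+  with delays of 1.0, 2.0, and 4.0 seconds between attempts
+  (retry_delay * retry_backoff_factor**attempt).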
+
+  Note:
+    This utility is designed for one-time setup operations and complements
+    Apache Beam's RequestResponseIO pattern. Use retry_with_backoff() for:
+
+    * Establishing client connections in __enter__() methods (e.g., creating
+      MilvusClient instances, database connections) before processing elements
+    * One-time setup/teardown operations in DoFn lifecycle methods
+    * Operations outside of per-element processing where retry is needed
+
+    For per-element operations (e.g., API calls within Caller.__call__),
+    use RequestResponseIO which already provides automatic retry with
+    exponential backoff, failure handling, caching, and other features.
+    See: https://beam.apache.org/documentation/io/built-in/webapis/
+
+  Args:
+    operation: Callable that performs the operation to retry. Should return
+      the result of the operation.
+    max_retries: Maximum number of retry attempts. Default is 3.
+    retry_delay: Initial delay in seconds between retries. Default is 1.0.
+    retry_backoff_factor: Multiplier for the delay after each retry. Default
+      is 2.0 (exponential backoff).
+    operation_name: Name of the operation for logging purposes. Default is
+      "operation".
+    exception_types: Tuple of exception types to catch and retry. Default is
+      (Exception,) which catches all exceptions.
+
+  Returns:
+    The result of the operation if successful.
+
+  Raises:
+    The last exception encountered if all retry attempts fail.
+
+  Example:
+    >>> def connect_to_service():
+    ...     return service.connect(host="localhost")
+    >>> client = retry_with_backoff(
+    ...     connect_to_service,
+    ...     max_retries=5,
+    ...     retry_delay=2.0,
+    ...     operation_name="service connection")
+  """
+  last_exception = None
+  for attempt in range(max_retries + 1):
+    try:
+      result = operation()
+      _LOGGER.info(
+          "Successfully completed %s on attempt %d",
+          operation_name,
+          attempt + 1)
+      return result
+    except exception_types as e:
+      last_exception = e
+      if attempt < max_retries:
+        delay = retry_delay * (retry_backoff_factor**attempt)
+        _LOGGER.warning(
+            "%s attempt %d failed: %s. Retrying in %.2f seconds...",
+            operation_name,
+            attempt + 1,
+            e,
+            delay)
+        time.sleep(delay)
+      else:
+        _LOGGER.error(
+            "Failed %s after %d attempts", operation_name, max_retries + 1)
+        raise last_exception
