This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 83ebe731133 [1/3] sdks/python: refactor Milvus-related utilities as preparation step for Milvus Sink I/O integration (#35708)
83ebe731133 is described below
commit 83ebe73113391c1650680d2665a15849b536e776
Author: Mohamed Awnallah <[email protected]>
AuthorDate: Wed Nov 12 20:50:46 2025 +0200
[1/3] sdks/python: refactor Milvus-related utilities as preparation step for Milvus Sink I/O integration (#35708)
* sdks/python: replace the deprecated testcontainer max tries
* sdks/python: handle transient testcontainer startup/teardown errors
* sdks/python: bump `testcontainers` py pkg version
* sdks/python: integrate milvus sink I/O
* sdks/python: fix linting issues
* sdks/python: add missing apache beam license header for `test_utils.py`
* notebooks/beam-ml: use new refactored code in milvus enrichment handler
* CHANGES.md: update release notes
* sdks/python: mark milvus itests with `require_docker_in_docker` marker
* sdks/python: override milvus db version with the default
* sdks/python: add missing import in rag utils
* sdks/python: fix linting issue
* rag/ingestion/milvus_search_itest.py: ensure flushing in-memory data before querying
* sdks/python: fix linting issues
* sdks/python: fix formatting issues
* sdks/python: fix arising linting issue
* rag: reuse `retry_with_backoff` for one-time setup operations
* sdks/python: fix linting issues
* sdks/python: fix py docs CI issue
* sdks/python: fix linting issues
* sdks/python: fix linting issues
* sdks/python: isolate milvus sink integration to be in follow-up PR
* CHANGES.md: remove milvus from release notes in the refactoring PR
* sdks/python: remove `with_sparse_embedding_spec` column specs builder
In this commit, we remove that builder method; it will be reintroduced
and used in the next Milvus sink integration PR
* sdks/python: fix linting issues
* Revert "notebooks/beam-ml: use new refactored code in milvus enrichment handler"
This reverts commit 461c8fee9d1d4b63b63558d188f88f3e79856309.
* sdks/python: fix linting issues
* sdks/python: fix linting issues
* sdks/python: fix linting issues
* sdks/python: fix linting issues
* CI: fix import errors in CI
* sdks/python: fix linting issues
* sdks/python: fix linting issues
* sdks/python: fix linting issues
* sdks/python: fix linting issues
---
.../transforms/elementwise/enrichment_test.py | 60 +--
.../apache_beam/ml/rag/enrichment/milvus_search.py | 133 ++-----
.../ml/rag/enrichment/milvus_search_it_test.py | 429 ++-------------------
.../ml/rag/ingestion/postgres_common.py | 56 +--
sdks/python/apache_beam/ml/rag/test_utils.py | 413 ++++++++++++++++++++
sdks/python/apache_beam/ml/rag/utils.py | 224 +++++++++++
6 files changed, 759 insertions(+), 556 deletions(-)
diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py
index c8e988a52c5..ed2b0c131e0 100644
--- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py
+++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py
@@ -52,13 +52,11 @@ try:
ConnectionConfig,
CloudSQLConnectionConfig,
ExternalSQLDBConnectionConfig)
- from apache_beam.ml.rag.enrichment.milvus_search import (
- MilvusConnectionParameters)
- from apache_beam.ml.rag.enrichment.milvus_search_it_test import (
- MilvusEnrichmentTestHelper,
- MilvusDBContainerInfo,
- parse_chunk_strings,
- assert_chunks_equivalent)
+ from apache_beam.ml.rag.enrichment.milvus_search import MilvusConnectionParameters
+ from apache_beam.ml.rag.test_utils import MilvusTestHelpers
+ from apache_beam.ml.rag.test_utils import VectorDBContainerInfo
+ from apache_beam.ml.rag.utils import parse_chunk_strings
from apache_beam.io.requestresponse import RequestResponseIO
except ImportError as e:
raise unittest.SkipTest(f'Examples dependencies are not installed: {str(e)}')
@@ -69,6 +67,11 @@ class TestContainerStartupError(Exception):
pass
+class TestContainerTeardownError(Exception):
+ """Raised when any test container fails to teardown."""
+ pass
+
+
def validate_enrichment_with_bigtable():
expected = '''[START enrichment_with_bigtable]
Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'})
@@ -186,7 +189,7 @@ class EnrichmentTest(unittest.TestCase):
output = mock_stdout.getvalue().splitlines()
expected = validate_enrichment_with_external_pg()
self.assertEqual(output, expected)
- except TestContainerStartupError as e:
+ except (TestContainerStartupError, TestContainerTeardownError) as e:
raise unittest.SkipTest(str(e))
except Exception as e:
self.fail(f"Test failed with unexpected error: {e}")
@@ -199,7 +202,7 @@ class EnrichmentTest(unittest.TestCase):
output = mock_stdout.getvalue().splitlines()
expected = validate_enrichment_with_external_mysql()
self.assertEqual(output, expected)
- except TestContainerStartupError as e:
+ except (TestContainerStartupError, TestContainerTeardownError) as e:
raise unittest.SkipTest(str(e))
except Exception as e:
self.fail(f"Test failed with unexpected error: {e}")
@@ -212,7 +215,7 @@ class EnrichmentTest(unittest.TestCase):
output = mock_stdout.getvalue().splitlines()
expected = validate_enrichment_with_external_sqlserver()
self.assertEqual(output, expected)
- except TestContainerStartupError as e:
+ except (TestContainerStartupError, TestContainerTeardownError) as e:
raise unittest.SkipTest(str(e))
except Exception as e:
self.fail(f"Test failed with unexpected error: {e}")
@@ -226,8 +229,8 @@ class EnrichmentTest(unittest.TestCase):
self.maxDiff = None
output = parse_chunk_strings(output)
expected = parse_chunk_strings(expected)
- assert_chunks_equivalent(output, expected)
- except TestContainerStartupError as e:
+ MilvusTestHelpers.assert_chunks_equivalent(output, expected)
+ except (TestContainerStartupError, TestContainerTeardownError) as e:
raise unittest.SkipTest(str(e))
except Exception as e:
self.fail(f"Test failed with unexpected error: {e}")
@@ -257,7 +260,7 @@ class EnrichmentTestHelpers:
@staticmethod
@contextmanager
def milvus_test_context():
- db: Optional[MilvusDBContainerInfo] = None
+ db: Optional[VectorDBContainerInfo] = None
try:
db = EnrichmentTestHelpers.pre_milvus_enrichment()
yield
@@ -370,23 +373,21 @@ class EnrichmentTestHelpers:
os.environ.pop('GOOGLE_CLOUD_SQL_DB_TABLE_ID', None)
@staticmethod
- def pre_milvus_enrichment() -> MilvusDBContainerInfo:
+ def pre_milvus_enrichment() -> VectorDBContainerInfo:
try:
- db = MilvusEnrichmentTestHelper.start_db_container()
+ db = MilvusTestHelpers.start_db_container()
+ connection_params = MilvusConnectionParameters(
+ uri=db.uri,
+ user=db.user,
+ password=db.password,
+ db_id=db.id,
+ token=db.token)
+ collection_name = MilvusTestHelpers.initialize_db_with_data(
+ connection_params)
except Exception as e:
raise TestContainerStartupError(
f"Milvus container failed to start: {str(e)}")
- connection_params = MilvusConnectionParameters(
- uri=db.uri,
- user=db.user,
- password=db.password,
- db_id=db.id,
- token=db.token)
-
- collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data(
- connection_params)
-
# Setup environment variables for db and collection configuration. This will
# be used downstream by the milvus enrichment handler.
os.environ['MILVUS_VECTOR_DB_URI'] = db.uri
@@ -399,8 +400,13 @@ class EnrichmentTestHelpers:
return db
@staticmethod
- def post_milvus_enrichment(db: MilvusDBContainerInfo):
- MilvusEnrichmentTestHelper.stop_db_container(db)
+ def post_milvus_enrichment(db: VectorDBContainerInfo):
+ try:
+ MilvusTestHelpers.stop_db_container(db)
+ except Exception as e:
+ raise TestContainerTeardownError(
+ f"Milvus container failed to tear down: {str(e)}")
+
os.environ.pop('MILVUS_VECTOR_DB_URI', None)
os.environ.pop('MILVUS_VECTOR_DB_USER', None)
os.environ.pop('MILVUS_VECTOR_DB_PASSWORD', None)
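
For readers scanning the hunks above: the snippet tests now treat container
lifecycle failures as environment problems rather than test failures. A
condensed, runnable sketch of that pattern (run_enrichment_snippet is a
hypothetical stand-in for the real pipeline body; only the exception classes
and the skip logic mirror the diff):

    import unittest


    class TestContainerStartupError(Exception):
      """Raised when any test container fails to start."""
      pass


    class TestContainerTeardownError(Exception):
      """Raised when any test container fails to tear down."""
      pass


    def run_enrichment_snippet():
      # Hypothetical stand-in for the real pipeline run between container
      # startup and teardown.
      raise TestContainerStartupError("Docker daemon not reachable")


    class EnrichmentTest(unittest.TestCase):
      def test_enrichment_with_milvus(self):
        try:
          run_enrichment_snippet()
        except (TestContainerStartupError, TestContainerTeardownError) as e:
          # Container problems skip the test instead of failing the suite.
          raise unittest.SkipTest(str(e))
        except Exception as e:
          self.fail(f"Test failed with unexpected error: {e}")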
diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py
index 8f631746748..41355e8c10a 100644
--- a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py
+++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py
@@ -32,9 +32,14 @@ from pymilvus import Hit
from pymilvus import Hits
from pymilvus import MilvusClient
from pymilvus import SearchResult
+from pymilvus.exceptions import MilvusException
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Embedding
+from apache_beam.ml.rag.utils import MilvusConnectionParameters
+from apache_beam.ml.rag.utils import MilvusHelpers
+from apache_beam.ml.rag.utils import retry_with_backoff
+from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
@@ -104,44 +109,6 @@ class MilvusBaseRanker:
return self.dict().__str__()
-@dataclass
-class MilvusConnectionParameters:
- """Parameters for establishing connections to Milvus servers.
-
- Args:
- uri: URI endpoint for connecting to Milvus server in the format
- "http(s)://hostname:port".
- user: Username for authentication. Required if authentication is enabled and
- not using token authentication.
- password: Password for authentication. Required if authentication is enabled
- and not using token authentication.
- db_id: Database ID to connect to. Specifies which Milvus database to use.
- Defaults to 'default'.
- token: Authentication token as an alternative to username/password.
- timeout: Connection timeout in seconds. Uses client default if None.
- max_retries: Maximum number of connection retry attempts. Defaults to 3.
- retry_delay: Initial delay between retries in seconds. Defaults to 1.0.
- retry_backoff_factor: Multiplier for retry delay after each attempt.
- Defaults to 2.0 (exponential backoff).
- kwargs: Optional keyword arguments for additional connection parameters.
- Enables forward compatibility.
- """
- uri: str
- user: str = field(default_factory=str)
- password: str = field(default_factory=str)
- db_id: str = "default"
- token: str = field(default_factory=str)
- timeout: Optional[float] = None
- max_retries: int = 3
- retry_delay: float = 1.0
- retry_backoff_factor: float = 2.0
- kwargs: Dict[str, Any] = field(default_factory=dict)
-
- def __post_init__(self):
- if not self.uri:
- raise ValueError("URI must be provided for Milvus connection")
-
-
@dataclass
class BaseSearchParameters:
"""Base parameters for both vector and keyword search operations.
@@ -361,7 +328,7 @@ class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]):
**kwargs):
"""
Example Usage:
- connection_paramters = MilvusConnectionParameters(
+ connection_parameters = MilvusConnectionParameters(
uri="http://localhost:19530")
search_parameters = MilvusSearchParameters(
collection_name="my_collection",
@@ -369,7 +336,7 @@ class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]):
collection_load_parameters = MilvusCollectionLoadParameters(
load_fields=["embedding", "metadata"]),
milvus_handler = MilvusSearchEnrichmentHandler(
- connection_paramters,
+ connection_parameters,
search_parameters,
collection_load_parameters=collection_load_parameters,
min_batch_size=10,
@@ -407,52 +374,43 @@ class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]):
'min_batch_size': min_batch_size, 'max_batch_size': max_batch_size
}
self.kwargs = kwargs
+ self._client = None
self.join_fn = join_fn
self.use_custom_types = True
def __enter__(self):
- import logging
- import time
-
- from pymilvus.exceptions import MilvusException
-
- connection_params = unpack_dataclass_with_kwargs(
- self._connection_parameters)
- collection_load_params = unpack_dataclass_with_kwargs(
- self._collection_load_parameters)
-
- # Extract retry parameters from connection_params
- max_retries = connection_params.pop('max_retries', 3)
- retry_delay = connection_params.pop('retry_delay', 1.0)
- retry_backoff_factor = connection_params.pop('retry_backoff_factor', 2.0)
-
- # Retry logic for MilvusClient connection
- last_exception = None
- for attempt in range(max_retries + 1):
- try:
- self._client = MilvusClient(**connection_params)
- self._client.load_collection(
+ """Enters the context manager and establishes Milvus connection.
+
+ Returns:
+ Self, enabling use in 'with' statements.
+ """
+ if not self._client:
+ connection_params = unpack_dataclass_with_kwargs(
+ self._connection_parameters)
+ collection_load_params = unpack_dataclass_with_kwargs(
+ self._collection_load_parameters)
+
+ # Extract retry parameters from connection_params.
+ max_retries = connection_params.pop('max_retries', 3)
+ retry_delay = connection_params.pop('retry_delay', 1.0)
+ retry_backoff_factor = connection_params.pop('retry_backoff_factor', 2.0)
+
+ def connect_and_load():
+ client = MilvusClient(**connection_params)
+ client.load_collection(
collection_name=self.collection_name,
partition_names=self.partition_names,
**collection_load_params)
- logging.info(
- "Successfully connected to Milvus on attempt %d", attempt + 1)
- return
- except MilvusException as e:
- last_exception = e
- if attempt < max_retries:
- delay = retry_delay * (retry_backoff_factor**attempt)
- logging.warning(
- "Milvus connection attempt %d failed: %s. "
- "Retrying in %.2f seconds...",
- attempt + 1,
- e,
- delay)
- time.sleep(delay)
- else:
- logging.error(
- "Failed to connect to Milvus after %d attempts", max_retries + 1)
- raise last_exception
+ return client
+
+ self._client = retry_with_backoff(
+ connect_and_load,
+ max_retries=max_retries,
+ retry_delay=retry_delay,
+ retry_backoff_factor=retry_backoff_factor,
+ operation_name="Milvus connection and collection load",
+ exception_types=(MilvusException, ))
+ return self
def __call__(self, request: Union[Chunk, List[Chunk]], *args,
**kwargs) -> List[Tuple[Chunk, Dict[str, Any]]]:
@@ -535,10 +493,7 @@ class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]):
raise ValueError(
f"Chunk {chunk.id} missing both text content and sparse embedding "
"required for keyword search")
-
- sparse_embedding = self.convert_sparse_embedding_to_milvus_format(
- chunk.sparse_embedding)
-
+ sparse_embedding = MilvusHelpers.sparse_embedding(chunk.sparse_embedding)
return chunk.content.text or sparse_embedding
def _get_call_response(
@@ -628,15 +583,3 @@ class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]):
def join_fn(left: Embedding, right: Dict[str, Any]) -> Embedding:
left.metadata['enrichment_data'] = right
return left
-
-
-def unpack_dataclass_with_kwargs(dataclass_instance):
- # Create a copy of the dataclass's __dict__.
- params_dict: dict = dataclass_instance.__dict__.copy()
-
- # Extract the nested kwargs dictionary.
- nested_kwargs = params_dict.pop('kwargs', {})
-
- # Merge the dictionaries, with nested_kwargs taking precedence
- # in case of duplicate keys.
- return {**params_dict, **nested_kwargs}
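
The net effect of the __enter__ hunk above is that kwargs flattening and
connection retries are now delegated to the shared utilities added in
utils.py further down. A minimal standalone sketch of the same pattern
(ExampleParams and connect() are illustrative stand-ins, not from the commit;
only unpack_dataclass_with_kwargs and retry_with_backoff are the real
helpers):

    from dataclasses import dataclass, field
    from typing import Any, Dict

    from apache_beam.ml.rag.utils import retry_with_backoff
    from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs


    @dataclass
    class ExampleParams:
      uri: str
      timeout: float = 5.0
      kwargs: Dict[str, Any] = field(default_factory=dict)


    params = ExampleParams(
        uri="http://localhost:19530", kwargs={"timeout": 30.0})
    # Nested kwargs take precedence on duplicate keys, so 'timeout' is 30.0.
    flat = unpack_dataclass_with_kwargs(params)
    assert flat["timeout"] == 30.0


    def connect():
      # Stand-in for MilvusClient(**flat) followed by load_collection(...).
      return object()


    client = retry_with_backoff(
        connect,
        max_retries=3,
        retry_delay=1.0,
        retry_backoff_factor=2.0,
        operation_name="example connection")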
diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py
index b3a0dcd5572..34cb3f9050f 100644
--- a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py
+++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py
@@ -15,25 +15,13 @@
# limitations under the License.
#
-import contextlib
-import logging
-import os
import platform
-import re
-import socket
-import tempfile
import unittest
-from collections import defaultdict
from dataclasses import dataclass
from dataclasses import field
-from typing import Callable
from typing import Dict
-from typing import List
-from typing import Optional
-from typing import cast
import pytest
-import yaml
import apache_beam as beam
from apache_beam.ml.rag.types import Chunk
@@ -44,18 +32,12 @@ from apache_beam.testing.util import assert_that
# pylint: disable=ungrouped-imports
try:
- from pymilvus import CollectionSchema
from pymilvus import DataType
from pymilvus import FieldSchema
from pymilvus import Function
from pymilvus import FunctionType
- from pymilvus import MilvusClient
from pymilvus import RRFRanker
from pymilvus.milvus_client import IndexParams
- from testcontainers.core.config import MAX_TRIES as TC_MAX_TRIES
- from testcontainers.core.config import testcontainers_config
- from testcontainers.core.generic import DbContainer
- from testcontainers.milvus import MilvusContainer
from apache_beam.ml.rag.enrichment.milvus_search import HybridSearchParameters
from apache_beam.ml.rag.enrichment.milvus_search import KeywordSearchMetrics
@@ -66,12 +48,12 @@ try:
from apache_beam.ml.rag.enrichment.milvus_search import MilvusSearchParameters
from apache_beam.ml.rag.enrichment.milvus_search import VectorSearchMetrics
from apache_beam.ml.rag.enrichment.milvus_search import VectorSearchParameters
+ from apache_beam.ml.rag.test_utils import MilvusTestHelpers
+ from apache_beam.ml.rag.test_utils import VectorDBContainerInfo
from apache_beam.transforms.enrichment import Enrichment
except ImportError as e:
raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}')
-_LOGGER = logging.getLogger(__name__)
-
def _construct_index_params():
index_params = IndexParams()
@@ -243,244 +225,6 @@ MILVUS_IT_CONFIG = {
}
-@dataclass
-class MilvusDBContainerInfo:
- container: DbContainer
- host: str
- port: int
- user: Optional[str] = ""
- password: Optional[str] = ""
- token: Optional[str] = ""
- id: Optional[str] = "default"
-
- @property
- def uri(self) -> str:
- return f"http://{self.host}:{self.port}"
-
-
-class CustomMilvusContainer(MilvusContainer):
- def __init__(
- self,
- image: str,
- service_container_port,
- healthcheck_container_port,
- **kwargs,
- ) -> None:
- # Skip the parent class's constructor and go straight to
- # GenericContainer.
- super(MilvusContainer, self).__init__(image=image, **kwargs)
- self.port = service_container_port
- self.healthcheck_port = healthcheck_container_port
- self.with_exposed_ports(service_container_port, healthcheck_container_port)
-
- # Get free host ports.
- service_host_port = MilvusEnrichmentTestHelper.find_free_port()
- healthcheck_host_port = MilvusEnrichmentTestHelper.find_free_port()
-
- # Bind container and host ports.
- self.with_bind_ports(service_container_port, service_host_port)
- self.with_bind_ports(healthcheck_container_port, healthcheck_host_port)
- self.cmd = "milvus run standalone"
-
- # Set environment variables needed for Milvus.
- envs = {
- "ETCD_USE_EMBED": "true",
- "ETCD_DATA_DIR": "/var/lib/milvus/etcd",
- "COMMON_STORAGETYPE": "local",
- "METRICS_PORT": str(healthcheck_container_port)
- }
- for env, value in envs.items():
- self.with_env(env, value)
-
-
-class MilvusEnrichmentTestHelper:
- # IMPORTANT: When upgrading the Milvus server version, ensure the pymilvus
- # Python SDK client in setup.py is updated to match. Referring to the Milvus
- # release notes compatibility matrix at
- # https://milvus.io/docs/release_notes.md or PyPI at
- # https://pypi.org/project/pymilvus/ for version compatibility.
- # Example: Milvus v2.6.0 requires pymilvus==2.6.0 (exact match required).
- @staticmethod
- def start_db_container(
- image="milvusdb/milvus:v2.5.10",
- max_vec_fields=5,
- vector_client_max_retries=3,
- tc_max_retries=TC_MAX_TRIES) -> Optional[MilvusDBContainerInfo]:
- service_container_port = MilvusEnrichmentTestHelper.find_free_port()
- healthcheck_container_port = MilvusEnrichmentTestHelper.find_free_port()
- user_yaml_creator = MilvusEnrichmentTestHelper.create_user_yaml
- with user_yaml_creator(service_container_port, max_vec_fields) as cfg:
- info = None
- testcontainers_config.max_tries = tc_max_retries
- for i in range(vector_client_max_retries):
- try:
- vector_db_container = CustomMilvusContainer(
- image=image,
- service_container_port=service_container_port,
- healthcheck_container_port=healthcheck_container_port)
- vector_db_container = vector_db_container.with_volume_mapping(
- cfg, "/milvus/configs/user.yaml")
- vector_db_container.start()
- host = vector_db_container.get_container_host_ip()
- port = vector_db_container.get_exposed_port(service_container_port)
- info = MilvusDBContainerInfo(vector_db_container, host, port)
- testcontainers_config.max_tries = TC_MAX_TRIES
- _LOGGER.info(
- "milvus db container started successfully on %s.", info.uri)
- break
- except Exception as e:
- stdout_logs, stderr_logs = vector_db_container.get_logs()
- stdout_logs = stdout_logs.decode("utf-8")
- stderr_logs = stderr_logs.decode("utf-8")
- _LOGGER.warning(
- "Retry %d/%d: Failed to start Milvus DB container. Reason: %s. "
- "STDOUT logs:\n%s\nSTDERR logs:\n%s",
- i + 1,
- vector_client_max_retries,
- e,
- stdout_logs,
- stderr_logs)
- if i == vector_client_max_retries - 1:
- _LOGGER.error(
- "Unable to start milvus db container for I/O tests after %d "
- "retries. Tests cannot proceed. STDOUT logs:\n%s\n"
- "STDERR logs:\n%s",
- vector_client_max_retries,
- stdout_logs,
- stderr_logs)
- raise e
- return info
-
- @staticmethod
- def stop_db_container(db_info: MilvusDBContainerInfo):
- if db_info is None:
- _LOGGER.warning("Milvus db info is None. Skipping stop operation.")
- return
- try:
- _LOGGER.debug("Stopping milvus db container.")
- db_info.container.stop()
- _LOGGER.info("milvus db container stopped successfully.")
- except Exception as e:
- _LOGGER.warning(
- "Error encountered while stopping milvus db container: %s", e)
-
- @staticmethod
- def initialize_db_with_data(connc_params: MilvusConnectionParameters):
- # Open the connection to the milvus db.
- client = MilvusClient(**connc_params.__dict__)
-
- # Configure schema.
- field_schemas: List[FieldSchema] = cast(
- List[FieldSchema], MILVUS_IT_CONFIG["fields"])
- schema = CollectionSchema(
- fields=field_schemas, functions=MILVUS_IT_CONFIG["functions"])
-
- # Create collection with the schema.
- collection_name = MILVUS_IT_CONFIG["collection_name"]
- index_function: Callable[[], IndexParams] = cast(
- Callable[[], IndexParams], MILVUS_IT_CONFIG["index"])
- client.create_collection(
- collection_name=collection_name,
- schema=schema,
- index_params=index_function())
-
- # Assert that collection was created.
- collection_error = f"Expected collection '{collection_name}' to be created."
- assert client.has_collection(collection_name), collection_error
-
- # Gather all fields we have excluding 'sparse_embedding_bm25' special field.
- fields = list(map(lambda field: field.name, field_schemas))
-
- # Prep data for indexing. Currently we can't insert sparse vectors for BM25
- # sparse embedding field as it would be automatically generated by Milvus
- # through the registered BM25 function.
- data_ready_to_index = []
- for doc in MILVUS_IT_CONFIG["corpus"]:
- item = {}
- for field in fields:
- if field.startswith("dense_embedding"):
- item[field] = doc["dense_embedding"]
- elif field == "sparse_embedding_inner_product":
- item[field] = doc["sparse_embedding"]
- elif field == "sparse_embedding_bm25":
- # It is automatically generated by Milvus from the content field.
- continue
- else:
- item[field] = doc[field]
- data_ready_to_index.append(item)
-
- # Index data.
- result = client.insert(
- collection_name=collection_name, data=data_ready_to_index)
-
- # Assert that the intended data has been properly indexed.
- insertion_err = f'failed to insert the {result["insert_count"]} data points'
- assert result["insert_count"] == len(data_ready_to_index), insertion_err
-
- # Release the collection from memory. It will be loaded lazily when the
- # enrichment handler is invoked.
- client.release_collection(collection_name)
-
- # Close the connection to the Milvus database, as no further preparation
- # operations are needed before executing the enrichment handler.
- client.close()
-
- return collection_name
-
- @staticmethod
- def find_free_port():
- """Find a free port on the local machine."""
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- # Bind to port 0, which asks OS to assign a free port.
- s.bind(('', 0))
- s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- # Return the port number assigned by OS.
- return s.getsockname()[1]
-
- @staticmethod
- @contextlib.contextmanager
- def create_user_yaml(service_port: int, max_vector_field_num=5):
- """Creates a temporary user.yaml file for Milvus configuration.
-
- This user yaml file overrides Milvus default configurations. It sets
- the Milvus service port to the specified container service port. The
- default for maxVectorFieldNum is 4, but we need 5
- (one unique field for each metric).
-
- Args:
- service_port: Port number for the Milvus service.
- max_vector_field_num: Max number of vec fields allowed per collection.
-
- Yields:
- str: Path to the created temporary yaml file.
- """
- with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml',
- delete=False) as temp_file:
- # Define the content for user.yaml.
- user_config = {
- 'proxy': {
- 'maxVectorFieldNum': max_vector_field_num, 'port': service_port
- },
- 'etcd': {
- 'use': {
- 'embed': True
- }, 'data': {
- 'dir': '/var/lib/milvus/etcd'
- }
- }
- }
-
- # Write the content to the file.
- yaml.dump(user_config, temp_file, default_flow_style=False)
- path = temp_file.name
-
- try:
- yield path
- finally:
- if os.path.exists(path):
- os.remove(path)
-
-
@pytest.mark.require_docker_in_docker
@unittest.skipUnless(
platform.system() == "Linux",
@@ -492,25 +236,24 @@ class MilvusEnrichmentTestHelper:
class TestMilvusSearchEnrichment(unittest.TestCase):
"""Tests for search functionality across all search strategies"""
- _db: MilvusDBContainerInfo
+ _db: VectorDBContainerInfo
@classmethod
def setUpClass(cls):
- cls._db = MilvusEnrichmentTestHelper.start_db_container()
+ cls._db = MilvusTestHelpers.start_db_container()
cls._connection_params = MilvusConnectionParameters(
uri=cls._db.uri,
user=cls._db.user,
password=cls._db.password,
- db_id=cls._db.id,
- token=cls._db.token,
- timeout=60.0) # Increase timeout to 60s for container startup
+ db_name=cls._db.id,
+ token=cls._db.token)
cls._collection_load_params = MilvusCollectionLoadParameters()
- cls._collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data(
- cls._connection_params)
+ cls._collection_name = MilvusTestHelpers.initialize_db_with_data(
+ cls._connection_params, MILVUS_IT_CONFIG)
@classmethod
def tearDownClass(cls):
- MilvusEnrichmentTestHelper.stop_db_container(cls._db)
+ MilvusTestHelpers.stop_db_container(cls._db)
cls._db = None
def test_invalid_query_on_non_existent_collection(self):
@@ -589,8 +332,8 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
with TestPipeline() as p:
result = (p | beam.Create(test_chunks) | Enrichment(handler))
assert_that(
- result,
- lambda actual: assert_chunks_equivalent(actual, expected_chunks))
+ result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent(
+ actual, expected_chunks))
def test_filtered_search_with_cosine_similarity_and_batching(self):
test_chunks = [
@@ -717,8 +460,8 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
with TestPipeline() as p:
result = (p | beam.Create(test_chunks) | Enrichment(handler))
assert_that(
- result,
- lambda actual: assert_chunks_equivalent(actual, expected_chunks))
+ result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent(
+ actual, expected_chunks))
def test_filtered_search_with_bm25_full_text_and_batching(self):
test_chunks = [
@@ -822,8 +565,8 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
with TestPipeline() as p:
result = (p | beam.Create(test_chunks) | Enrichment(handler))
assert_that(
- result,
- lambda actual: assert_chunks_equivalent(actual, expected_chunks))
+ result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent(
+ actual, expected_chunks))
def test_vector_search_with_euclidean_distance(self):
test_chunks = [
@@ -963,8 +706,8 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
with TestPipeline() as p:
result = (p | beam.Create(test_chunks) | Enrichment(handler))
assert_that(
- result,
- lambda actual: assert_chunks_equivalent(actual, expected_chunks))
+ result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent(
+ actual, expected_chunks))
def test_vector_search_with_inner_product_similarity(self):
test_chunks = [
@@ -1103,8 +846,8 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
with TestPipeline() as p:
result = (p | beam.Create(test_chunks) | Enrichment(handler))
assert_that(
- result,
- lambda actual: assert_chunks_equivalent(actual, expected_chunks))
+ result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent(
+ actual, expected_chunks))
def test_keyword_search_with_inner_product_sparse_embedding(self):
test_chunks = [
@@ -1168,8 +911,8 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
with TestPipeline() as p:
result = (p | beam.Create(test_chunks) | Enrichment(handler))
assert_that(
- result,
- lambda actual: assert_chunks_equivalent(actual, expected_chunks))
+ result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent(
+ actual, expected_chunks))
def test_hybrid_search(self):
test_chunks = [
@@ -1241,134 +984,8 @@ class TestMilvusSearchEnrichment(unittest.TestCase):
with TestPipeline() as p:
result = (p | beam.Create(test_chunks) | Enrichment(handler))
assert_that(
- result,
- lambda actual: assert_chunks_equivalent(actual, expected_chunks))
-
-
-def parse_chunk_strings(chunk_str_list: List[str]) -> List[Chunk]:
- parsed_chunks = []
-
- # Define safe globals and disable built-in functions for safety.
- safe_globals = {
- 'Chunk': Chunk,
- 'Content': Content,
- 'Embedding': Embedding,
- 'defaultdict': defaultdict,
- 'list': list,
- '__builtins__': {}
- }
-
- for raw_str in chunk_str_list:
- try:
- # replace "<class 'list'>" with actual list reference.
- cleaned_str = re.sub(
- r"defaultdict\(<class 'list'>", "defaultdict(list", raw_str)
-
- # Evaluate string in restricted environment.
- chunk = eval(cleaned_str, safe_globals) # pylint: disable=eval-used
- if isinstance(chunk, Chunk):
- parsed_chunks.append(chunk)
- else:
- raise ValueError("Parsed object is not a Chunk instance")
- except Exception as e:
- raise ValueError(f"Error parsing string:\n{raw_str}\n{e}")
-
- return parsed_chunks
-
-
-def assert_chunks_equivalent(
- actual_chunks: List[Chunk], expected_chunks: List[Chunk]):
- """assert_chunks_equivalent checks for presence rather than exact match"""
- # Sort both lists by ID to ensure consistent ordering.
- actual_sorted = sorted(actual_chunks, key=lambda c: c.id)
- expected_sorted = sorted(expected_chunks, key=lambda c: c.id)
-
- actual_len = len(actual_sorted)
- expected_len = len(expected_sorted)
- err_msg = (
- f"Different number of chunks, actual: {actual_len}, "
- f"expected: {expected_len}")
- assert actual_len == expected_len, err_msg
-
- for actual, expected in zip(actual_sorted, expected_sorted):
- # Assert that IDs match.
- assert actual.id == expected.id
-
- # Assert that dense embeddings match.
- err_msg = f"Dense embedding mismatch for chunk {actual.id}"
- assert actual.dense_embedding == expected.dense_embedding, err_msg
-
- # Assert that sparse embeddings match.
- err_msg = f"Sparse embedding mismatch for chunk {actual.id}"
- assert actual.sparse_embedding == expected.sparse_embedding, err_msg
-
- # Assert that text content match.
- err_msg = f"Text Content mismatch for chunk {actual.id}"
- assert actual.content.text == expected.content.text, err_msg
-
- # For enrichment_data, be more flexible.
- # If "expected" has values for enrichment_data but actual doesn't, that's
- # acceptable since vector search results can vary based on many factors
- # including implementation details, vector database state, and slight
- # variations in similarity calculations.
-
- # First ensure the enrichment data key exists.
- err_msg = f"Missing enrichment_data key in chunk {actual.id}"
- assert 'enrichment_data' in actual.metadata, err_msg
-
- # For enrichment_data, ensure consistent ordering of results.
- actual_data = actual.metadata['enrichment_data']
- expected_data = expected.metadata['enrichment_data']
-
- # If actual has enrichment data, then perform detailed validation.
- if actual_data and actual_data.get('id'):
- # Validate IDs have consistent ordering.
- actual_ids = sorted(actual_data['id'])
- expected_ids = sorted(expected_data['id'])
- err_msg = f"IDs in enrichment_data don't match for chunk {actual.id}"
- assert actual_ids == expected_ids, err_msg
-
- # Ensure the distance key exist.
- err_msg = f"Missing distance key in metadata {actual.id}"
- assert 'distance' in actual_data, err_msg
-
- # Validate distances exist and have same length as IDs.
- actual_distances = actual_data['distance']
- expected_distances = expected_data['distance']
- err_msg = (
- "Number of distances doesn't match number of IDs for "
- f"chunk {actual.id}")
- assert len(actual_distances) == len(expected_distances), err_msg
-
- # Ensure the fields key exist.
- err_msg = f"Missing fields key in metadata {actual.id}"
- assert 'fields' in actual_data, err_msg
-
- # Validate fields have consistent content.
- # Sort fields by 'id' to ensure consistent ordering.
- actual_fields_sorted = sorted(
- actual_data['fields'], key=lambda f: f.get('id', 0))
- expected_fields_sorted = sorted(
- expected_data['fields'], key=lambda f: f.get('id', 0))
-
- # Compare field IDs.
- actual_field_ids = [f.get('id') for f in actual_fields_sorted]
- expected_field_ids = [f.get('id') for f in expected_fields_sorted]
- err_msg = f"Field IDs don't match for chunk {actual.id}"
- assert actual_field_ids == expected_field_ids, err_msg
-
- # Compare field content.
- for a_f, e_f in zip(actual_fields_sorted, expected_fields_sorted):
- # Ensure the id key exist.
- err_msg = f"Missing id key in metadata.fields {actual.id}"
- assert 'id' in a_f
-
- err_msg = f"Field ID mismatch chunk {actual.id}"
- assert a_f['id'] == e_f['id'], err_msg
-
- # Validate field metadata.
- err_msg = f"Field Metadata doesn't match for chunk {actual.id}"
- assert a_f['metadata'] == e_f['metadata'], err_msg
+ result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent(
+ actual, expected_chunks))
if __name__ == '__main__':
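
Each of the repeated assertion hunks above swaps the old module-level
assert_chunks_equivalent for the static helper on MilvusTestHelpers.
Standalone, the helper's semantics are simple: it returns None when the chunk
lists match on the checked fields and raises AssertionError with a
descriptive message otherwise. A minimal sketch (assuming the Milvus test
extras are installed; the empty lists are placeholders for the per-test Chunk
lists the suite builds):

    from apache_beam.ml.rag.test_utils import MilvusTestHelpers

    actual_chunks = []    # placeholder: in the suite this is the pipeline
    expected_chunks = []  # output captured via assert_that
    MilvusTestHelpers.assert_chunks_equivalent(actual_chunks, expected_chunks)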
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py
index eca740a4e9c..68afa56e399 100644
--- a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py
+++ b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py
@@ -30,16 +30,16 @@ from apache_beam.ml.rag.types import Chunk
def chunk_embedding_fn(chunk: Chunk) -> str:
"""Convert embedding to PostgreSQL array string.
-
+
Formats dense embedding as a PostgreSQL-compatible array string.
Example: [1.0, 2.0] -> '{1.0,2.0}'
-
+
Args:
chunk: Input Chunk object.
-
+
Returns:
str: PostgreSQL array string representation of the embedding.
-
+
Raises:
ValueError: If chunk has no dense embedding.
"""
@@ -51,7 +51,7 @@ def chunk_embedding_fn(chunk: Chunk) -> str:
@dataclass
class ColumnSpec:
"""Specification for mapping Chunk fields to SQL columns for insertion.
-
+
Defines how to extract and format values from Chunks into database columns,
handling the full pipeline from Python value to SQL insertion.
@@ -71,7 +71,7 @@ class ColumnSpec:
Common examples:
- "::float[]" for vector arrays
- "::jsonb" for JSON data
-
+
Examples:
Basic text column (uses standard JDBC type mapping):
>>> ColumnSpec.text(
@@ -83,7 +83,7 @@ class ColumnSpec:
Vector column with explicit array casting:
>>> ColumnSpec.vector(
... column_name="embedding",
- ... value_fn=lambda chunk: '{' +
+ ... value_fn=lambda chunk: '{' +
... ','.join(map(str, chunk.embedding.dense_embedding)) + '}'
... )
# Results in: INSERT INTO table (embedding) VALUES (?::float[])
@@ -168,17 +168,17 @@ class ColumnSpecsBuilder:
convert_fn: Optional[Callable[[str], Any]] = None,
sql_typecast: Optional[str] = None) -> 'ColumnSpecsBuilder':
"""Add ID :class:`.ColumnSpec` with optional type and conversion.
-
+
Args:
column_name: Name for the ID column (defaults to "id")
python_type: Python type for the column (defaults to str)
convert_fn: Optional function to convert the chunk ID
If None, uses ID as-is
sql_typecast: Optional SQL type cast
-
+
Returns:
Self for method chaining
-
+
Example:
>>> builder.with_id_spec(
... column_name="doc_id",
@@ -205,17 +205,17 @@ class ColumnSpecsBuilder:
convert_fn: Optional[Callable[[str], Any]] = None,
sql_typecast: Optional[str] = None) -> 'ColumnSpecsBuilder':
"""Add content :class:`.ColumnSpec` with optional type and conversion.
-
+
Args:
column_name: Name for the content column (defaults to "content")
python_type: Python type for the column (defaults to str)
convert_fn: Optional function to convert the content text
If None, uses content text as-is
sql_typecast: Optional SQL type cast
-
+
Returns:
Self for method chaining
-
+
Example:
>>> builder.with_content_spec(
... column_name="content_length",
@@ -244,17 +244,17 @@ class ColumnSpecsBuilder:
convert_fn: Optional[Callable[[Dict[str, Any]], Any]] = None,
sql_typecast: Optional[str] = "::jsonb") -> 'ColumnSpecsBuilder':
"""Add metadata :class:`.ColumnSpec` with optional type and conversion.
-
+
Args:
column_name: Name for the metadata column (defaults to "metadata")
python_type: Python type for the column (defaults to str)
convert_fn: Optional function to convert the metadata dictionary
If None and python_type is str, converts to JSON string
sql_typecast: Optional SQL type cast (defaults to "::jsonb")
-
+
Returns:
Self for method chaining
-
+
Example:
>>> builder.with_metadata_spec(
... column_name="meta_tags",
@@ -283,19 +283,19 @@ class ColumnSpecsBuilder:
convert_fn: Optional[Callable[[List[float]], Any]] = None
) -> 'ColumnSpecsBuilder':
"""Add embedding :class:`.ColumnSpec` with optional conversion.
-
+
Args:
column_name: Name for the embedding column (defaults to "embedding")
convert_fn: Optional function to convert the dense embedding values
If None, uses default PostgreSQL array format
-
+
Returns:
Self for method chaining
-
+
Example:
>>> builder.with_embedding_spec(
... column_name="embedding_vector",
- ... convert_fn=lambda values: '{' + ','.join(f"{x:.4f}"
+ ... convert_fn=lambda values: '{' + ','.join(f"{x:.4f}"
... for x in values) + '}'
... )
"""
@@ -330,7 +330,7 @@ class ColumnSpecsBuilder:
desired type. If None, value is used as-is
default: Default value if field is missing from metadata
sql_typecast: Optional SQL type cast (e.g. "::timestamp")
-
+
Returns:
Self for chaining
@@ -385,17 +385,17 @@ class ColumnSpecsBuilder:
def add_custom_column_spec(self, spec: ColumnSpec) -> 'ColumnSpecsBuilder':
"""Add a custom :class:`.ColumnSpec` to the builder.
-
+
Use this method when you need complete control over the :class:`.ColumnSpec`,
including custom value extraction and type handling.
-
+
Args:
spec: A :class:`.ColumnSpec` instance defining the column name, type,
value extraction, and optional SQL type casting.
-
+
Returns:
Self for method chaining
-
+
Examples:
Custom text column from chunk metadata:
@@ -430,12 +430,12 @@ class ConflictResolution:
IGNORE: Skips conflicting records.
update_fields: Optional list of fields to update on conflict. If None,
all non-conflict fields are updated.
-
+
Examples:
Simple primary key:
>>> ConflictResolution("id")
-
+
Composite key with specific update fields:
>>> ConflictResolution(
@@ -443,7 +443,7 @@ class ConflictResolution:
... action="UPDATE",
... update_fields=["embedding", "content"]
... )
-
+
Ignore conflicts:
>>> ConflictResolution(
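
The postgres_common.py changes above are docstring whitespace cleanups only;
behavior is unchanged. For orientation, the builder methods those docstrings
document chain roughly like this (a sketch, assuming the postgres ingestion
extras are installed; the final step that turns the builder into column specs
sits outside the hunks shown):

    from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder

    builder = (
        ColumnSpecsBuilder()
        .with_id_spec(column_name="doc_id")
        .with_content_spec(column_name="content")
        .with_metadata_spec()  # defaults to a str column cast with ::jsonb
        .with_embedding_spec(column_name="embedding"))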
diff --git a/sdks/python/apache_beam/ml/rag/test_utils.py b/sdks/python/apache_beam/ml/rag/test_utils.py
new file mode 100644
index 00000000000..f4acb105892
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/test_utils.py
@@ -0,0 +1,413 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import contextlib
+import logging
+import os
+import socket
+import tempfile
+import unittest
+from dataclasses import dataclass
+from typing import Callable
+from typing import List
+from typing import Optional
+from typing import cast
+
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.utils import retry_with_backoff
+
+# pylint: disable=ungrouped-imports
+try:
+ import yaml
+ from pymilvus import CollectionSchema
+ from pymilvus import FieldSchema
+ from pymilvus import MilvusClient
+ from pymilvus.exceptions import MilvusException
+ from pymilvus.milvus_client import IndexParams
+ from testcontainers.core.config import testcontainers_config
+ from testcontainers.core.generic import DbContainer
+ from testcontainers.milvus import MilvusContainer
+
+ from apache_beam.ml.rag.enrichment.milvus_search import MilvusConnectionParameters
+except ImportError as e:
+ raise unittest.SkipTest(f'RAG test util dependencies not installed: {str(e)}')
+
+_LOGGER = logging.getLogger(__name__)
+
+
+@dataclass
+class VectorDBContainerInfo:
+ """Container information for vector database test instances.
+
+ Holds connection details and container reference for testing with
+ vector databases like Milvus in containerized environments.
+ """
+ container: DbContainer
+ host: str
+ port: int
+ user: str = ""
+ password: str = ""
+ token: str = ""
+ id: str = "default"
+
+ @property
+ def uri(self) -> str:
+ return f"http://{self.host}:{self.port}"
+
+
+class TestHelpers:
+ @staticmethod
+ def find_free_port():
+ """Find a free port on the local machine."""
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ # Bind to port 0, which asks OS to assign a free port.
+ s.bind(('', 0))
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ # Return the port number assigned by OS.
+ return s.getsockname()[1]
+
+
+class CustomMilvusContainer(MilvusContainer):
+ """Custom Milvus container with configurable ports and environment setup.
+
+ Extends MilvusContainer to provide custom port binding and environment
+ configuration for testing with standalone Milvus instances.
+ """
+ def __init__(
+ self,
+ image: str,
+ service_container_port,
+ healthcheck_container_port,
+ **kwargs,
+ ) -> None:
+ # Skip the parent class's constructor and go straight to
+ # GenericContainer.
+ super(MilvusContainer, self).__init__(image=image, **kwargs)
+ self.port = service_container_port
+ self.healthcheck_port = healthcheck_container_port
+ self.with_exposed_ports(service_container_port, healthcheck_container_port)
+
+ # Get free host ports.
+ service_host_port = TestHelpers.find_free_port()
+ healthcheck_host_port = TestHelpers.find_free_port()
+
+ # Bind container and host ports.
+ self.with_bind_ports(service_container_port, service_host_port)
+ self.with_bind_ports(healthcheck_container_port, healthcheck_host_port)
+ self.cmd = "milvus run standalone"
+
+ # Set environment variables needed for Milvus.
+ envs = {
+ "ETCD_USE_EMBED": "true",
+ "ETCD_DATA_DIR": "/var/lib/milvus/etcd",
+ "COMMON_STORAGETYPE": "local",
+ "METRICS_PORT": str(healthcheck_container_port)
+ }
+ for env, value in envs.items():
+ self.with_env(env, value)
+
+
+class MilvusTestHelpers:
+ """Helper utilities for testing Milvus vector database operations.
+
+ Provides static methods for managing test containers, configuration files,
+ and chunk comparison utilities for Milvus-based integration tests.
+ """
+ # IMPORTANT: When upgrading the Milvus server version, ensure the pymilvus
+ # Python SDK client in setup.py is updated to match. Refer to the Milvus
+ # release notes compatibility matrix at
+ # https://milvus.io/docs/release_notes.md or PyPI at
+ # https://pypi.org/project/pymilvus/ for version compatibility.
+ # Example: Milvus v2.6.0 requires pymilvus==2.6.0 (exact match required).
+ @staticmethod
+ def start_db_container(
+ image="milvusdb/milvus:v2.5.10",
+ max_vec_fields=5,
+ vector_client_max_retries=3,
+ tc_max_retries=None) -> Optional[VectorDBContainerInfo]:
+ service_container_port = TestHelpers.find_free_port()
+ healthcheck_container_port = TestHelpers.find_free_port()
+ user_yaml_creator = MilvusTestHelpers.create_user_yaml
+ with user_yaml_creator(service_container_port, max_vec_fields) as cfg:
+ info = None
+ original_tc_max_tries = testcontainers_config.max_tries
+ if tc_max_retries is not None:
+ testcontainers_config.max_tries = tc_max_retries
+ for i in range(vector_client_max_retries):
+ try:
+ vector_db_container = CustomMilvusContainer(
+ image=image,
+ service_container_port=service_container_port,
+ healthcheck_container_port=healthcheck_container_port)
+ vector_db_container = vector_db_container.with_volume_mapping(
+ cfg, "/milvus/configs/user.yaml")
+ vector_db_container.start()
+ host = vector_db_container.get_container_host_ip()
+ port = vector_db_container.get_exposed_port(service_container_port)
+ info = VectorDBContainerInfo(vector_db_container, host, port)
+ _LOGGER.info(
+ "milvus db container started successfully on %s.", info.uri)
+ except Exception as e:
+ stdout_logs, stderr_logs = vector_db_container.get_logs()
+ stdout_logs = stdout_logs.decode("utf-8")
+ stderr_logs = stderr_logs.decode("utf-8")
+ _LOGGER.warning(
+ "Retry %d/%d: Failed to start Milvus DB container. Reason: %s. "
+ "STDOUT logs:\n%s\nSTDERR logs:\n%s",
+ i + 1,
+ vector_client_max_retries,
+ e,
+ stdout_logs,
+ stderr_logs)
+ if i == vector_client_max_retries - 1:
+ _LOGGER.error(
+ "Unable to start milvus db container for I/O tests after %d "
+ "retries. Tests cannot proceed. STDOUT logs:\n%s\n"
+ "STDERR logs:\n%s",
+ vector_client_max_retries,
+ stdout_logs,
+ stderr_logs)
+ raise e
+ finally:
+ testcontainers_config.max_tries = original_tc_max_tries
+ return info
+
+ @staticmethod
+ def stop_db_container(db_info: VectorDBContainerInfo):
+ if db_info is None:
+ _LOGGER.warning("Milvus db info is None. Skipping stop operation.")
+ return
+ _LOGGER.debug("Stopping milvus db container.")
+ db_info.container.stop()
+ _LOGGER.info("milvus db container stopped successfully.")
+
+ @staticmethod
+ def initialize_db_with_data(
+ connc_params: MilvusConnectionParameters, config: dict):
+ # Open the connection to the milvus db with retry.
+ def create_client():
+ return MilvusClient(**connc_params.__dict__)
+
+ client = retry_with_backoff(
+ create_client,
+ max_retries=3,
+ retry_delay=1.0,
+ operation_name="Test Milvus client connection",
+ exception_types=(MilvusException, ))
+
+ # Configure schema.
+ field_schemas: List[FieldSchema] = cast(List[FieldSchema], config["fields"])
+ schema = CollectionSchema(
+ fields=field_schemas, functions=config["functions"])
+
+ # Create collection with the schema.
+ collection_name = config["collection_name"]
+ index_function: Callable[[], IndexParams] = cast(
+ Callable[[], IndexParams], config["index"])
+ client.create_collection(
+ collection_name=collection_name,
+ schema=schema,
+ index_params=index_function())
+
+ # Assert that collection was created.
+ collection_error = f"Expected collection '{collection_name}' to be created."
+ assert client.has_collection(collection_name), collection_error
+
+ # Gather all fields we have excluding 'sparse_embedding_bm25' special field.
+ fields = list(map(lambda field: field.name, field_schemas))
+
+ # Prep data for indexing. Currently we can't insert sparse vectors for BM25
+ # sparse embedding field as it would be automatically generated by Milvus
+ # through the registered BM25 function.
+ data_ready_to_index = []
+ for doc in config["corpus"]:
+ item = {}
+ for field in fields:
+ if field.startswith("dense_embedding"):
+ item[field] = doc["dense_embedding"]
+ elif field == "sparse_embedding_inner_product":
+ item[field] = doc["sparse_embedding"]
+ elif field == "sparse_embedding_bm25":
+ # It is automatically generated by Milvus from the content field.
+ continue
+ else:
+ item[field] = doc[field]
+ data_ready_to_index.append(item)
+
+ # Index data.
+ result = client.insert(
+ collection_name=collection_name, data=data_ready_to_index)
+
+ # Assert that the intended data has been properly indexed.
+ insertion_err = f'failed to insert the {result["insert_count"]} data points'
+ assert result["insert_count"] == len(data_ready_to_index), insertion_err
+
+ # Release the collection from memory. It will be loaded lazily when the
+ # enrichment handler is invoked.
+ client.release_collection(collection_name)
+
+ # Close the connection to the Milvus database, as no further preparation
+ # operations are needed before executing the enrichment handler.
+ client.close()
+
+ return collection_name
+
+ @staticmethod
+ @contextlib.contextmanager
+ def create_user_yaml(service_port: int, max_vector_field_num=5):
+ """Creates a temporary user.yaml file for Milvus configuration.
+
+ This user yaml file overrides Milvus default configurations. It sets
+ the Milvus service port to the specified container service port. The
+ default for maxVectorFieldNum is 4, but we need 5
+ (one unique field for each metric).
+
+ Args:
+ service_port: Port number for the Milvus service.
+ max_vector_field_num: Max number of vec fields allowed per collection.
+
+ Yields:
+ str: Path to the created temporary yaml file.
+ """
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml',
+ delete=False) as temp_file:
+ # Define the content for user.yaml.
+ user_config = {
+ 'proxy': {
+ 'maxVectorFieldNum': max_vector_field_num, 'port': service_port
+ },
+ 'etcd': {
+ 'use': {
+ 'embed': True
+ }, 'data': {
+ 'dir': '/var/lib/milvus/etcd'
+ }
+ }
+ }
+
+ # Write the content to the file.
+ yaml.dump(user_config, temp_file, default_flow_style=False)
+ path = temp_file.name
+
+ try:
+ yield path
+ finally:
+ if os.path.exists(path):
+ os.remove(path)
+
+ @staticmethod
+ def assert_chunks_equivalent(
+ actual_chunks: List[Chunk], expected_chunks: List[Chunk]):
+ """assert_chunks_equivalent checks for presence rather than exact match"""
+ # Sort both lists by ID to ensure consistent ordering.
+ actual_sorted = sorted(actual_chunks, key=lambda c: c.id)
+ expected_sorted = sorted(expected_chunks, key=lambda c: c.id)
+
+ actual_len = len(actual_sorted)
+ expected_len = len(expected_sorted)
+ err_msg = (
+ f"Different number of chunks, actual: {actual_len}, "
+ f"expected: {expected_len}")
+ assert actual_len == expected_len, err_msg
+
+ for actual, expected in zip(actual_sorted, expected_sorted):
+ # Assert that IDs match.
+ assert actual.id == expected.id
+
+ # Assert that dense embeddings match.
+ err_msg = f"Dense embedding mismatch for chunk {actual.id}"
+ assert actual.dense_embedding == expected.dense_embedding, err_msg
+
+ # Assert that sparse embeddings match.
+ err_msg = f"Sparse embedding mismatch for chunk {actual.id}"
+ assert actual.sparse_embedding == expected.sparse_embedding, err_msg
+
+ # Assert that text content matches.
+ err_msg = f"Text Content mismatch for chunk {actual.id}"
+ assert actual.content.text == expected.content.text, err_msg
+
+ # For enrichment_data, be more flexible.
+ # If "expected" has values for enrichment_data but actual doesn't, that's
+ # acceptable since vector search results can vary based on many factors
+ # including implementation details, vector database state, and slight
+ # variations in similarity calculations.
+
+ # First ensure the enrichment data key exists.
+ err_msg = f"Missing enrichment_data key in chunk {actual.id}"
+ assert 'enrichment_data' in actual.metadata, err_msg
+
+ # For enrichment_data, ensure consistent ordering of results.
+ actual_data = actual.metadata['enrichment_data']
+ expected_data = expected.metadata['enrichment_data']
+
+ # If actual has enrichment data, then perform detailed validation.
+ if actual_data:
+ # Ensure the id key exists.
+ err_msg = f"Missing id key in metadata {actual.id}"
+ assert 'id' in actual_data, err_msg
+
+ # Validate IDs have consistent ordering.
+ actual_ids = sorted(actual_data['id'])
+ expected_ids = sorted(expected_data['id'])
+ err_msg = f"IDs in enrichment_data don't match for chunk {actual.id}"
+ assert actual_ids == expected_ids, err_msg
+
+ # Ensure the distance key exists.
+ err_msg = f"Missing distance key in metadata {actual.id}"
+ assert 'distance' in actual_data, err_msg
+
+ # Validate distances exist and have same length as IDs.
+ actual_distances = actual_data['distance']
+ expected_distances = expected_data['distance']
+ err_msg = (
+ "Number of distances doesn't match number of IDs for "
+ f"chunk {actual.id}")
+ assert len(actual_distances) == len(expected_distances), err_msg
+
+ # Ensure the fields key exists.
+ err_msg = f"Missing fields key in metadata {actual.id}"
+ assert 'fields' in actual_data, err_msg
+
+ # Validate fields have consistent content.
+ # Sort fields by 'id' to ensure consistent ordering.
+ actual_fields_sorted = sorted(
+ actual_data['fields'], key=lambda f: f.get('id', 0))
+ expected_fields_sorted = sorted(
+ expected_data['fields'], key=lambda f: f.get('id', 0))
+
+ # Compare field IDs.
+ actual_field_ids = [f.get('id') for f in actual_fields_sorted]
+ expected_field_ids = [f.get('id') for f in expected_fields_sorted]
+ err_msg = f"Field IDs don't match for chunk {actual.id}"
+ assert actual_field_ids == expected_field_ids, err_msg
+
+ # Compare field content.
+ for a_f, e_f in zip(actual_fields_sorted, expected_fields_sorted):
+ # Ensure the id key exists.
+ err_msg = f"Missing id key in metadata.fields {actual.id}"
+ assert 'id' in a_f, err_msg
+
+ err_msg = f"Field ID mismatch chunk {actual.id}"
+ assert a_f['id'] == e_f['id'], err_msg
+
+ # Validate field metadata.
+ err_msg = f"Field Metadata doesn't match for chunk {actual.id}"
+ assert a_f['metadata'] == e_f['metadata'], err_msg
+
+
+if __name__ == '__main__':
+ unittest.main()
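
Put together, the new test_utils.py gives integration tests a container
lifecycle like the following sketch (it mirrors the setUpClass/tearDownClass
hunks in milvus_search_it_test.py above; EXAMPLE_CONFIG is a placeholder for
a MILVUS_IT_CONFIG-style dict and Docker must be available):

    import unittest

    from apache_beam.ml.rag.test_utils import MilvusTestHelpers
    from apache_beam.ml.rag.test_utils import VectorDBContainerInfo
    from apache_beam.ml.rag.utils import MilvusConnectionParameters

    # Placeholder: supply fields/functions/index/corpus as in the it_test.
    EXAMPLE_CONFIG: dict = {}


    class ExampleMilvusIT(unittest.TestCase):
      _db: VectorDBContainerInfo

      @classmethod
      def setUpClass(cls):
        cls._db = MilvusTestHelpers.start_db_container()
        params = MilvusConnectionParameters(
            uri=cls._db.uri,
            user=cls._db.user,
            password=cls._db.password,
            db_name=cls._db.id,
            token=cls._db.token)
        cls._collection = MilvusTestHelpers.initialize_db_with_data(
            params, EXAMPLE_CONFIG)

      @classmethod
      def tearDownClass(cls):
        MilvusTestHelpers.stop_db_container(cls._db)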
diff --git a/sdks/python/apache_beam/ml/rag/utils.py b/sdks/python/apache_beam/ml/rag/utils.py
new file mode 100644
index 00000000000..d45e99be0ec
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/utils.py
@@ -0,0 +1,224 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import re
+import time
+import uuid
+from collections import defaultdict
+from dataclasses import dataclass
+from dataclasses import field
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Type
+
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.types import Content
+from apache_beam.ml.rag.types import Embedding
+
+_LOGGER = logging.getLogger(__name__)
+
+# Default batch size for writing data to Milvus, matching
+# JdbcIO.DEFAULT_BATCH_SIZE.
+DEFAULT_WRITE_BATCH_SIZE = 1000
+
+
+@dataclass
+class MilvusConnectionParameters:
+ """Configurations for establishing connections to Milvus servers.
+
+ Args:
+ uri: URI endpoint for connecting to Milvus server in the format
+ "http(s)://hostname:port".
+ user: Username for authentication. Required if authentication is enabled and
+ not using token authentication.
+ password: Password for authentication. Required if authentication is enabled
+ and not using token authentication.
+ db_name: Database name to connect to. Specifies which Milvus database to
+ use. Defaults to 'default'.
+ token: Authentication token as an alternative to username/password.
+ timeout: Connection timeout in seconds. Uses client default if None.
+ kwargs: Optional keyword arguments for additional connection parameters.
+ Enables forward compatibility.
+ """
+ uri: str
+ user: str = field(default_factory=str)
+ password: str = field(default_factory=str)
+ db_name: str = "default"
+ token: str = field(default_factory=str)
+ timeout: Optional[float] = None
+ kwargs: Dict[str, Any] = field(default_factory=dict)
+
+ def __post_init__(self):
+ if not self.uri:
+ raise ValueError("URI must be provided for Milvus connection")
+
+ # Generate unique alias if not provided. One-to-one mapping between alias
+ # and connection - each alias represents exactly one Milvus connection.
+ if "alias" not in self.kwargs:
+ alias = f"milvus_conn_{uuid.uuid4().hex[:8]}"
+ self.kwargs["alias"] = alias
+
+
+class MilvusHelpers:
+ """Utility class providing helper methods for Milvus vector db operations."""
+ @staticmethod
+ def sparse_embedding(
+ sparse_vector: Tuple[List[int], List[float]]) -> Optional[Dict[int, float]]:
+ if not sparse_vector:
+ return None
+ # Converts sparse embedding from (indices, values) tuple format to
+ # Milvus-compatible values dict format {dimension_index: value, ...}.
+ indices, values = sparse_vector
+ return {int(idx): float(val) for idx, val in zip(indices, values)}
+
+
+def parse_chunk_strings(chunk_str_list: List[str]) -> List[Chunk]:
+ parsed_chunks = []
+
+ # Define safe globals and disable built-in functions for safety.
+ safe_globals = {
+ 'Chunk': Chunk,
+ 'Content': Content,
+ 'Embedding': Embedding,
+ 'defaultdict': defaultdict,
+ 'list': list,
+ '__builtins__': {}
+ }
+
+ for raw_str in chunk_str_list:
+ try:
+ # replace "<class 'list'>" with actual list reference.
+ cleaned_str = re.sub(
+ r"defaultdict\(<class 'list'>", "defaultdict(list", raw_str)
+
+ # Evaluate string in restricted environment.
+ chunk = eval(cleaned_str, safe_globals) # pylint: disable=eval-used
+ if isinstance(chunk, Chunk):
+ parsed_chunks.append(chunk)
+ else:
+ raise ValueError("Parsed object is not a Chunk instance")
+ except Exception as e:
+ raise ValueError(f"Error parsing string:\n{raw_str}\n{e}")
+
+ return parsed_chunks
+
+
+def unpack_dataclass_with_kwargs(dataclass_instance):
+ """Unpacks dataclass fields into a flat dict, merging kwargs with precedence.
+
+ Args:
+ dataclass_instance: Dataclass instance to unpack.
+
+ Returns:
+ dict: Flattened dictionary with kwargs taking precedence over fields.
+ """
+ # Create a copy of the dataclass's __dict__.
+ params_dict: dict = dataclass_instance.__dict__.copy()
+
+ # Extract the nested kwargs dictionary.
+ nested_kwargs = params_dict.pop('kwargs', {})
+
+ # Merge the dictionaries, with nested_kwargs taking precedence
+ # in case of duplicate keys.
+ return {**params_dict, **nested_kwargs}
+
+
+def retry_with_backoff(
+ operation: Callable[[], Any],
+ max_retries: int = 3,
+ retry_delay: float = 1.0,
+ retry_backoff_factor: float = 2.0,
+ operation_name: str = "operation",
+ exception_types: Tuple[Type[BaseException], ...] = (Exception, )
+) -> Any:
+ """Executes an operation with retry logic and exponential backoff.
+
+ This is a generic retry utility that can be used for any operation that may
+ fail transiently. It retries the operation with exponential backoff between
+ attempts.
+
+ Note:
+ This utility is designed for one-time setup operations and complements
+ Apache Beam's RequestResponseIO pattern. Use retry_with_backoff() for:
+
+ * Establishing client connections in __enter__() methods (e.g., creating
+ MilvusClient instances, database connections) before processing elements
+ * One-time setup/teardown operations in DoFn lifecycle methods
+ * Operations outside of per-element processing where retry is needed
+
+ For per-element operations (e.g., API calls within Caller.__call__),
+ use RequestResponseIO which already provides automatic retry with
+ exponential backoff, failure handling, caching, and other features.
+ See: https://beam.apache.org/documentation/io/built-in/webapis/
+
+ Args:
+ operation: Callable that performs the operation to retry. Should return
+ the result of the operation.
+ max_retries: Maximum number of retry attempts. Default is 3.
+ retry_delay: Initial delay in seconds between retries. Default is 1.0.
+ retry_backoff_factor: Multiplier for the delay after each retry. Default
+ is 2.0 (exponential backoff).
+ operation_name: Name of the operation for logging purposes. Default is
+ "operation".
+ exception_types: Tuple of exception types to catch and retry. Default is
+ (Exception,) which catches all exceptions.
+
+ Returns:
+ The result of the operation if successful.
+
+ Raises:
+ The last exception encountered if all retry attempts fail.
+
+ Example:
+ >>> def connect_to_service():
+ ... return service.connect(host="localhost")
+ >>> client = retry_with_backoff(
+ ... connect_to_service,
+ ... max_retries=5,
+ ... retry_delay=2.0,
+ ... operation_name="service connection")
+ """
+ last_exception = None
+ for attempt in range(max_retries + 1):
+ try:
+ result = operation()
+ _LOGGER.info(
+ "Successfully completed %s on attempt %d",
+ operation_name,
+ attempt + 1)
+ return result
+ except exception_types as e:
+ last_exception = e
+ if attempt < max_retries:
+ delay = retry_delay * (retry_backoff_factor**attempt)
+ _LOGGER.warning(
+ "%s attempt %d failed: %s. Retrying in %.2f seconds...",
+ operation_name,
+ attempt + 1,
+ e,
+ delay)
+ time.sleep(delay)
+ else:
+ _LOGGER.error(
+ "Failed %s after %d attempts", operation_name, max_retries + 1)
+ raise last_exception
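
Finally, the two small utilities near the top of utils.py are easiest to
grasp by their input/output shapes. A quick sketch (values are illustrative;
the alias assertion relies on the __post_init__ shown above):

    from apache_beam.ml.rag.utils import MilvusConnectionParameters
    from apache_beam.ml.rag.utils import MilvusHelpers
    from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs

    # (indices, values) tuple -> {dimension_index: value} dict for Milvus.
    assert MilvusHelpers.sparse_embedding(
        ([0, 2], [0.5, 0.25])) == {0: 0.5, 2: 0.25}

    # Flattening merges nested kwargs with precedence over regular fields;
    # __post_init__ also injects a unique connection alias into kwargs.
    params = MilvusConnectionParameters(
        uri="http://localhost:19530", kwargs={"timeout": 30.0})
    flat = unpack_dataclass_with_kwargs(params)
    assert flat["timeout"] == 30.0
    assert flat["alias"].startswith("milvus_conn_")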