gemini-code-assist[bot] commented on code in PR #36729:
URL: https://github.com/apache/beam/pull/36729#discussion_r2539805004
########## sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py: ########## @@ -0,0 +1,359 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional + +from pymilvus import MilvusClient +from pymilvus.exceptions import MilvusException + +import apache_beam as beam +from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig +from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.utils import DEFAULT_WRITE_BATCH_SIZE +from apache_beam.ml.rag.utils import MilvusConnectionParameters +from apache_beam.ml.rag.utils import MilvusHelpers +from apache_beam.ml.rag.utils import retry_with_backoff +from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs +from apache_beam.transforms import DoFn + +_LOGGER = logging.getLogger(__name__) + + +@dataclass +class MilvusWriteConfig: + """Configuration parameters for writing data to Milvus collections. + + This class defines the parameters needed to write data to a Milvus collection, + including collection targeting, batching behavior, and operation timeouts. + + Args: + collection_name: Name of the target Milvus collection to write data to. + Must be a non-empty string. + partition_name: Name of the specific partition within the collection to + write to. If empty, writes to the default partition. + timeout: Maximum time in seconds to wait for write operations to complete. + If None, uses the client's default timeout. + write_config: Configuration for write operations including batch size and + other write-specific settings. + kwargs: Additional keyword arguments for write operations. Enables forward + compatibility with future Milvus client parameters. + """ + collection_name: str + partition_name: str = "" + timeout: Optional[float] = None + write_config: WriteConfig = field(default_factory=WriteConfig) + kwargs: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if not self.collection_name: + raise ValueError("Collection name must be provided") + + @property + def write_batch_size(self): + """Returns the batch size for write operations. + + Returns: + The configured batch size, or DEFAULT_WRITE_BATCH_SIZE if not specified. + """ + return self.write_config.write_batch_size or DEFAULT_WRITE_BATCH_SIZE + + +@dataclass +class MilvusVectorWriterConfig(VectorDatabaseWriteConfig): + """Configuration for writing vector data to Milvus collections. 
+ + This class extends VectorDatabaseWriteConfig to provide Milvus-specific + configuration for ingesting vector embeddings and associated metadata. + It defines how Apache Beam chunks are converted to Milvus records and + handles the write operation parameters. + + The configuration includes connection parameters, write settings, and + column specifications that determine how chunk data is mapped to Milvus + fields. + + Args: + connection_params: Configuration for connecting to the Milvus server, + including URI, credentials, and connection options. + write_config: Configuration for write operations including collection name, + partition, batch size, and timeouts. + column_specs: List of column specifications defining how chunk fields are + mapped to Milvus collection fields. Defaults to standard RAG fields + (id, embedding, sparse_embedding, content, metadata). + + Example: + config = MilvusVectorWriterConfig( + connection_params=MilvusConnectionParameters( + uri="http://localhost:19530"), + write_config=MilvusWriteConfig(collection_name="my_collection"), + column_specs=MilvusVectorWriterConfig.default_column_specs()) + """ + connection_params: MilvusConnectionParameters + write_config: MilvusWriteConfig + column_specs: List[ColumnSpec] = field( + default_factory=lambda: MilvusVectorWriterConfig.default_column_specs()) + + def create_converter(self) -> Callable[[Chunk], Dict[str, Any]]: + """Creates a function to convert Apache Beam Chunks to Milvus records. + + Returns: + A function that takes a Chunk and returns a dictionary representing + a Milvus record with fields mapped according to column_specs. + """ + def convert(chunk: Chunk) -> Dict[str, Any]: + result = {} + for col in self.column_specs: + result[col.column_name] = col.value_fn(chunk) + return result + + return convert + + def create_write_transform(self) -> beam.PTransform: + """Creates the Apache Beam transform for writing to Milvus. + + Returns: + A PTransform that can be applied to a PCollection of Chunks to write + them to the configured Milvus collection. + """ + return _WriteToMilvusVectorDatabase(self) + + @staticmethod + def default_column_specs() -> List[ColumnSpec]: + """Returns default column specifications for RAG use cases. + + Creates column mappings for standard RAG fields: id, dense embedding, + sparse embedding, content text, and metadata. These specifications + define how Chunk fields are converted to Milvus-compatible formats. + + Returns: + List of ColumnSpec objects defining the default field mappings. + """ + column_specs = ColumnSpecsBuilder() + return column_specs\ + .with_id_spec()\ + .with_embedding_spec(convert_fn=lambda values: list(values))\ + .with_sparse_embedding_spec(conv_fn=MilvusHelpers.sparse_embedding)\ + .with_content_spec()\ + .with_metadata_spec(convert_fn=lambda values: dict(values))\ + .build() + + +class _WriteToMilvusVectorDatabase(beam.PTransform): + """Apache Beam PTransform for writing vector data to Milvus. + + This transform handles the conversion of Apache Beam Chunks to Milvus records + and coordinates the write operations. It applies the configured converter + function and uses a DoFn for batched writes to optimize performance. + + Args: + config: MilvusVectorWriterConfig containing all necessary parameters for + the write operation. + """ + def __init__(self, config: MilvusVectorWriterConfig): + self.config = config + + def expand(self, pcoll: beam.PCollection[Chunk]): + """Expands the PTransform to convert chunks and write to Milvus. 
+ + Args: + pcoll: PCollection of Chunk objects to write to Milvus. + + Returns: + PCollection of the same Chunk objects after writing to Milvus. Review Comment:  The docstring incorrectly states that this transform returns a `PCollection` of `Chunk` objects. The transform chain converts `Chunk` objects into dictionaries and then passes them to `_WriteMilvusFn`, which yields these dictionaries. The return type should be documented as a `PCollection` of dictionaries. ```suggestion PCollection of dictionaries representing the records written to Milvus. ``` ########## sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py: ########## @@ -0,0 +1,359 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional + +from pymilvus import MilvusClient +from pymilvus.exceptions import MilvusException + +import apache_beam as beam +from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig +from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.utils import DEFAULT_WRITE_BATCH_SIZE +from apache_beam.ml.rag.utils import MilvusConnectionParameters +from apache_beam.ml.rag.utils import MilvusHelpers +from apache_beam.ml.rag.utils import retry_with_backoff +from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs +from apache_beam.transforms import DoFn + +_LOGGER = logging.getLogger(__name__) + + +@dataclass +class MilvusWriteConfig: + """Configuration parameters for writing data to Milvus collections. + + This class defines the parameters needed to write data to a Milvus collection, + including collection targeting, batching behavior, and operation timeouts. + + Args: + collection_name: Name of the target Milvus collection to write data to. + Must be a non-empty string. + partition_name: Name of the specific partition within the collection to + write to. If empty, writes to the default partition. + timeout: Maximum time in seconds to wait for write operations to complete. + If None, uses the client's default timeout. + write_config: Configuration for write operations including batch size and + other write-specific settings. + kwargs: Additional keyword arguments for write operations. Enables forward + compatibility with future Milvus client parameters. 
+ """ + collection_name: str + partition_name: str = "" + timeout: Optional[float] = None + write_config: WriteConfig = field(default_factory=WriteConfig) + kwargs: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if not self.collection_name: + raise ValueError("Collection name must be provided") + + @property + def write_batch_size(self): + """Returns the batch size for write operations. + + Returns: + The configured batch size, or DEFAULT_WRITE_BATCH_SIZE if not specified. + """ + return self.write_config.write_batch_size or DEFAULT_WRITE_BATCH_SIZE + + +@dataclass +class MilvusVectorWriterConfig(VectorDatabaseWriteConfig): + """Configuration for writing vector data to Milvus collections. + + This class extends VectorDatabaseWriteConfig to provide Milvus-specific + configuration for ingesting vector embeddings and associated metadata. + It defines how Apache Beam chunks are converted to Milvus records and + handles the write operation parameters. + + The configuration includes connection parameters, write settings, and + column specifications that determine how chunk data is mapped to Milvus + fields. + + Args: + connection_params: Configuration for connecting to the Milvus server, + including URI, credentials, and connection options. + write_config: Configuration for write operations including collection name, + partition, batch size, and timeouts. + column_specs: List of column specifications defining how chunk fields are + mapped to Milvus collection fields. Defaults to standard RAG fields + (id, embedding, sparse_embedding, content, metadata). + + Example: + config = MilvusVectorWriterConfig( + connection_params=MilvusConnectionParameters( + uri="http://localhost:19530"), + write_config=MilvusWriteConfig(collection_name="my_collection"), + column_specs=MilvusVectorWriterConfig.default_column_specs()) + """ + connection_params: MilvusConnectionParameters + write_config: MilvusWriteConfig + column_specs: List[ColumnSpec] = field( + default_factory=lambda: MilvusVectorWriterConfig.default_column_specs()) + + def create_converter(self) -> Callable[[Chunk], Dict[str, Any]]: + """Creates a function to convert Apache Beam Chunks to Milvus records. + + Returns: + A function that takes a Chunk and returns a dictionary representing + a Milvus record with fields mapped according to column_specs. + """ + def convert(chunk: Chunk) -> Dict[str, Any]: + result = {} + for col in self.column_specs: + result[col.column_name] = col.value_fn(chunk) + return result + + return convert + + def create_write_transform(self) -> beam.PTransform: + """Creates the Apache Beam transform for writing to Milvus. + + Returns: + A PTransform that can be applied to a PCollection of Chunks to write + them to the configured Milvus collection. + """ + return _WriteToMilvusVectorDatabase(self) + + @staticmethod + def default_column_specs() -> List[ColumnSpec]: + """Returns default column specifications for RAG use cases. + + Creates column mappings for standard RAG fields: id, dense embedding, + sparse embedding, content text, and metadata. These specifications + define how Chunk fields are converted to Milvus-compatible formats. + + Returns: + List of ColumnSpec objects defining the default field mappings. 
+ """ + column_specs = ColumnSpecsBuilder() + return column_specs\ + .with_id_spec()\ + .with_embedding_spec(convert_fn=lambda values: list(values))\ + .with_sparse_embedding_spec(conv_fn=MilvusHelpers.sparse_embedding)\ + .with_content_spec()\ + .with_metadata_spec(convert_fn=lambda values: dict(values))\ + .build() + + +class _WriteToMilvusVectorDatabase(beam.PTransform): + """Apache Beam PTransform for writing vector data to Milvus. + + This transform handles the conversion of Apache Beam Chunks to Milvus records + and coordinates the write operations. It applies the configured converter + function and uses a DoFn for batched writes to optimize performance. + + Args: + config: MilvusVectorWriterConfig containing all necessary parameters for + the write operation. + """ + def __init__(self, config: MilvusVectorWriterConfig): + self.config = config + + def expand(self, pcoll: beam.PCollection[Chunk]): + """Expands the PTransform to convert chunks and write to Milvus. + + Args: + pcoll: PCollection of Chunk objects to write to Milvus. + + Returns: + PCollection of the same Chunk objects after writing to Milvus. + """ + return ( + pcoll + | "Convert to Records" >> beam.Map(self.config.create_converter()) + | beam.ParDo( + _WriteMilvusFn( + self.config.connection_params, self.config.write_config))) + + +class _WriteMilvusFn(DoFn): + """DoFn that handles batched writes to Milvus. + + This DoFn accumulates records in batches and flushes them to Milvus when + the batch size is reached or when the bundle finishes. This approach + optimizes performance by reducing the number of individual write operations. + + Args: + connection_params: Configuration for connecting to the Milvus server. + write_config: Configuration for write operations including batch size + and collection details. + """ + def __init__( + self, + connection_params: MilvusConnectionParameters, + write_config: MilvusWriteConfig): + self._connection_params = connection_params + self._write_config = write_config + self.batch = [] + + def process(self, element, *args, **kwargs): + """Processes individual records, batching them for efficient writes. + + Args: + element: A dictionary representing a Milvus record to write. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Yields: + The original element after adding it to the batch. + """ + _ = args, kwargs # Unused parameters + self.batch.append(element) + if len(self.batch) >= self._write_config.write_batch_size: + self._flush() + yield element + + def finish_bundle(self): + """Called when a bundle finishes processing. + + Flushes any remaining records in the batch to ensure all data is written. + """ + self._flush() + + def _flush(self): + """Flushes the current batch of records to Milvus. + + Creates a MilvusSink connection and writes all batched records, + then clears the batch for the next set of records. + """ + if len(self.batch) == 0: + return + with _MilvusSink(self._connection_params, self._write_config) as sink: + sink.write(self.batch) + self.batch = [] + + def display_data(self): + """Returns display data for monitoring and debugging. + + Returns: + Dictionary containing database, collection, and batch size information + for display in the Apache Beam monitoring UI. 
+ """ + res = super().display_data() + res["database"] = self._connection_params.db_name + res["collection"] = self._write_config.collection_name + res["batch_size"] = self._write_config.write_batch_size + return res + + +class _MilvusSink: + """Low-level sink for writing data directly to Milvus. + + This class handles the direct interaction with the Milvus client for + upsert operations. It manages the connection lifecycle and provides + context manager support for proper resource cleanup. + + Args: + connection_params: Configuration for connecting to the Milvus server. + write_config: Configuration for write operations including collection + and partition targeting. + """ + def __init__( + self, + connection_params: MilvusConnectionParameters, + write_config: MilvusWriteConfig): + self._connection_params = connection_params + self._write_config = write_config + self._client = None + + def write(self, documents): + """Writes a batch of documents to the Milvus collection. + + Performs an upsert operation to insert new documents or update existing + ones based on primary key. After the upsert, flushes the collection to + ensure data persistence. + + Args: + documents: List of dictionaries representing Milvus records to write. + Each dictionary should contain fields matching the collection schema. + """ + if not self._client: + self._client = MilvusClient( + **unpack_dataclass_with_kwargs(self._connection_params)) Review Comment:  The `_MilvusSink` is used as a context manager, which guarantees that `__enter__` is called to initialize `self._client` with retry logic. This explicit check and client initialization in the `write` method is redundant and bypasses the robust connection logic in `__enter__`. It's safer to rely on the context manager to handle client setup. ########## sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py: ########## @@ -0,0 +1,642 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import platform +import unittest +import uuid +from typing import Callable +from typing import cast + +import pytest +from pymilvus import CollectionSchema +from pymilvus import DataType +from pymilvus import FieldSchema +from pymilvus import MilvusClient +from pymilvus.exceptions import MilvusException +from pymilvus.milvus_client import IndexParams + +import apache_beam as beam +from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig +from apache_beam.ml.rag.test_utils import MilvusTestHelpers +from apache_beam.ml.rag.test_utils import VectorDBContainerInfo +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding +from apache_beam.ml.rag.utils import MilvusConnectionParameters +from apache_beam.ml.rag.utils import retry_with_backoff +from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs +from apache_beam.testing.test_pipeline import TestPipeline + +try: + from apache_beam.ml.rag.ingestion.milvus_search import MilvusVectorWriterConfig + from apache_beam.ml.rag.ingestion.milvus_search import MilvusWriteConfig +except ImportError as e: + raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}') + + +def _construct_index_params(): + index_params = IndexParams() + + # Dense vector index for dense embeddings. + index_params.add_index( + field_name="embedding", + index_name="embedding_ivf_flat", + index_type="IVF_FLAT", + metric_type="COSINE", + params={"nlist": 1}) + + # Sparse vector index for sparse embeddings. + index_params.add_index( + field_name="sparse_embedding", + index_name="sparse_embedding_inverted_index", + index_type="SPARSE_INVERTED_INDEX", + metric_type="IP", + params={"inverted_index_algo": "TAAT_NAIVE"}) + + return index_params + + +MILVUS_INGESTION_IT_CONFIG = { + "fields": [ + FieldSchema( + name="id", dtype=DataType.INT64, is_primary=True, auto_id=False), + FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=1000), + FieldSchema(name="metadata", dtype=DataType.JSON), + FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=3), + FieldSchema( + name="sparse_embedding", dtype=DataType.SPARSE_FLOAT_VECTOR) + ], + "index": _construct_index_params, + "corpus": [ + Chunk( + id=1, # type: ignore[arg-type] + content=Content(text="Test document one"), + metadata={"source": "test1"}, + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], + sparse_embedding=([1, 2], [0.1, 0.2])), + ), + Chunk( + id=2, # type: ignore[arg-type] + content=Content(text="Test document two"), + metadata={"source": "test2"}, + embedding=Embedding( + dense_embedding=[0.2, 0.3, 0.4], + sparse_embedding=([2, 3], [0.3, 0.1]), + ), + ), + Chunk( + id=3, # type: ignore[arg-type] + content=Content(text="Test document three"), + metadata={"source": "test3"}, + embedding=Embedding( + dense_embedding=[0.3, 0.4, 0.5], + sparse_embedding=([3, 4], [0.4, 0.2]), + ), + ) + ] +} + + +def create_collection_with_partition( + client: MilvusClient, + collection_name: str, + partition_name: str = '', + fields=None): + + if fields is None: + fields = MILVUS_INGESTION_IT_CONFIG["fields"] + + # Configure schema. + schema = CollectionSchema(fields=fields) + + # Configure index. + index_function: Callable[[], IndexParams] = cast( + Callable[[], IndexParams], MILVUS_INGESTION_IT_CONFIG["index"]) + + # Create collection with schema. + client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_function()) + + # Create partition within the collection. 
+ client.create_partition( + collection_name=collection_name, partition_name=partition_name) + + msg = f"Expected collection '{collection_name}' to be created." + assert client.has_collection(collection_name), msg + + msg = f"Expected partition '{partition_name}' to be created." + assert client.has_partition(collection_name, partition_name), msg + + # Release the collection from memory. We don't need that on pure writing. + client.release_collection(collection_name) + + +def drop_collection(client: MilvusClient, collection_name: str): + try: + client.drop_collection(collection_name) + assert not client.has_collection(collection_name) + except Exception: + # Silently ignore connection errors during cleanup. + pass + + [email protected]_docker_in_docker [email protected]( + platform.system() == "Linux", + "Test runs only on Linux due to lack of support, as yet, for nested " + "virtualization in CI environments on Windows/macOS. Many CI providers run " + "tests in virtualized environments, and nested virtualization " + "(Docker inside a VM) is either unavailable or has several issues on " + "non-Linux platforms.") +class TestMilvusVectorWriterConfig(unittest.TestCase): + """Integration tests for Milvus vector database ingestion functionality""" + + _db: VectorDBContainerInfo + + @classmethod + def setUpClass(cls): + cls._db = MilvusTestHelpers.start_db_container() + cls._connection_config = MilvusConnectionParameters( + uri=cls._db.uri, + user=cls._db.user, + password=cls._db.password, + db_name=cls._db.id, + token=cls._db.token) + + @classmethod + def tearDownClass(cls): + MilvusTestHelpers.stop_db_container(cls._db) + cls._db = None + + def setUp(self): + self.write_test_pipeline = TestPipeline() + self.write_test_pipeline.not_use_test_runner_api = True + self._collection_name = f"test_collection_{self._testMethodName}" + self._partition_name = f"test_partition_{self._testMethodName}" + config = unpack_dataclass_with_kwargs(self._connection_config) + config["alias"] = f"milvus_conn_{uuid.uuid4().hex[:8]}" + + # Use retry_with_backoff for test client connection. + def create_client(): + return MilvusClient(**config) + + self._test_client = retry_with_backoff( + create_client, + max_retries=3, + retry_delay=1.0, + operation_name="Test Milvus client connection", + exception_types=(MilvusException, )) + + create_collection_with_partition( + self._test_client, self._collection_name, self._partition_name) + + def tearDown(self): + drop_collection(self._test_client, self._collection_name) + self._test_client.close() + + def test_invalid_write_on_non_existent_collection(self): + non_existent_collection = "nonexistent_collection" + + test_chunks = MILVUS_INGESTION_IT_CONFIG["corpus"] + + write_config = MilvusWriteConfig( + collection_name=non_existent_collection, + write_config=WriteConfig(write_batch_size=1)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, + write_config=write_config, + ) + + # Write pipeline. + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Assert on what should happen. 
+ self.assertIn("can't find collection", str(context.exception).lower()) + + def test_invalid_write_on_non_existent_partition(self): + non_existent_partition = "nonexistent_partition" + + test_chunks = MILVUS_INGESTION_IT_CONFIG["corpus"] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=non_existent_partition, + write_config=WriteConfig(write_batch_size=1)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + # Write pipeline. + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Assert on what should happen. + self.assertIn("partition not found", str(context.exception).lower()) + + def test_invalid_write_on_missing_primary_key_in_entity(self): + test_chunks = [ + Chunk( + content=Content(text="Test content without ID"), + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], + sparse_embedding=([1, 2], [0.1, 0.2])), + metadata={"source": "test"}) + ] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=WriteConfig(write_batch_size=1)) + + # Deliberately remove id primary key from the entity. + specs = MilvusVectorWriterConfig.default_column_specs() + for i, spec in enumerate(specs): + if spec.column_name == "id": + del specs[i] + break + + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, + write_config=write_config, + column_specs=specs) + + # Write pipeline. + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Assert on what should happen. + self.assertIn( + "insert missed an field `id` to collection", + str(context.exception).lower()) + + def test_write_on_auto_id_primary_key(self): + auto_id_collection = f"auto_id_collection_{self._testMethodName}" + auto_id_partition = f"auto_id_partition_{self._testMethodName}" + auto_id_fields = [ + FieldSchema( + name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), + FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=1000), + FieldSchema(name="metadata", dtype=DataType.JSON), + FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=3), + FieldSchema( + name="sparse_embedding", dtype=DataType.SPARSE_FLOAT_VECTOR) + ] + + # Create collection with an auto id field. 
+ create_collection_with_partition( + client=self._test_client, + collection_name=auto_id_collection, + partition_name=auto_id_partition, + fields=auto_id_fields) + + test_chunks = [ + Chunk( + id=1, + content=Content(text="Test content without ID"), + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], + sparse_embedding=([1, 2], [0.1, 0.2])), + metadata={"source": "test"}) + ] + + write_config = MilvusWriteConfig( + collection_name=auto_id_collection, + partition_name=auto_id_partition, + write_config=WriteConfig(write_batch_size=1)) + + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + with self.write_test_pipeline as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + self._test_client.flush(auto_id_collection) + self._test_client.load_collection(auto_id_collection) + result = self._test_client.query( + collection_name=auto_id_collection, + partition_names=[auto_id_partition], + limit=3) + + # Test there is only one item in the result and the ID is not equal to one. + self.assertEqual(len(result), len(test_chunks)) + result_item = dict(result[0]) + self.assertNotEqual(result_item["id"], 1) + + def test_write_on_existent_collection_with_default_schema(self): + test_chunks = MILVUS_INGESTION_IT_CONFIG["corpus"] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=WriteConfig(write_batch_size=3)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + with self.write_test_pipeline as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Verify data was written successfully. + self._test_client.flush(self._collection_name) + self._test_client.load_collection(self._collection_name) + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + limit=10) + + self.assertEqual(len(result), len(test_chunks)) + + # Verify each chunk was written correctly. + result_by_id = {item["id"]: item for item in result} + for chunk in test_chunks: + self.assertIn(chunk.id, result_by_id) + result_item = result_by_id[chunk.id] + self.assertEqual( + result_item["content"], + chunk.content.text + if hasattr(chunk.content, 'text') else chunk.content) Review Comment:  The `hasattr(chunk.content, 'text')` check is confusing. According to the `Chunk` type definition, `chunk.content` is always a `Content` object, which has a `text` attribute. The `else chunk.content` branch would compare a string with a `Content` object, which is likely not the intention. The assertion can be simplified. ```python self.assertEqual(result_item["content"], chunk.content.text) ``` ########## CHANGES.md: ########## @@ -75,6 +75,9 @@ * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * Python examples added for Milvus search enrichment handler on [Beam Website](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment-milvus/) including jupyter notebook example (Python) ([#36176](https://github.com/apache/beam/issues/36176)). +* Milvus sink I/O connector added (Python) ([#36702]( + https://github.com/apache/beam/issues/36702)). Now Beam has full support for + Milvus integration including Milvus enrichment and sink operations. Review Comment:  For better readability of the raw markdown file, consider keeping the link and its description on a single line. 
```suggestion * Milvus sink I/O connector added (Python) ([#36702](https://github.com/apache/beam/issues/36702)). Now Beam has full support for Milvus integration including Milvus enrichment and sink operations. ``` ########## sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py: ########## @@ -0,0 +1,642 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import platform +import unittest +import uuid +from typing import Callable +from typing import cast + +import pytest +from pymilvus import CollectionSchema +from pymilvus import DataType +from pymilvus import FieldSchema +from pymilvus import MilvusClient +from pymilvus.exceptions import MilvusException +from pymilvus.milvus_client import IndexParams + +import apache_beam as beam +from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig +from apache_beam.ml.rag.test_utils import MilvusTestHelpers +from apache_beam.ml.rag.test_utils import VectorDBContainerInfo +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding +from apache_beam.ml.rag.utils import MilvusConnectionParameters +from apache_beam.ml.rag.utils import retry_with_backoff +from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs +from apache_beam.testing.test_pipeline import TestPipeline + +try: + from apache_beam.ml.rag.ingestion.milvus_search import MilvusVectorWriterConfig + from apache_beam.ml.rag.ingestion.milvus_search import MilvusWriteConfig +except ImportError as e: + raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}') + + +def _construct_index_params(): + index_params = IndexParams() + + # Dense vector index for dense embeddings. + index_params.add_index( + field_name="embedding", + index_name="embedding_ivf_flat", + index_type="IVF_FLAT", + metric_type="COSINE", + params={"nlist": 1}) + + # Sparse vector index for sparse embeddings. 
+ index_params.add_index( + field_name="sparse_embedding", + index_name="sparse_embedding_inverted_index", + index_type="SPARSE_INVERTED_INDEX", + metric_type="IP", + params={"inverted_index_algo": "TAAT_NAIVE"}) + + return index_params + + +MILVUS_INGESTION_IT_CONFIG = { + "fields": [ + FieldSchema( + name="id", dtype=DataType.INT64, is_primary=True, auto_id=False), + FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=1000), + FieldSchema(name="metadata", dtype=DataType.JSON), + FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=3), + FieldSchema( + name="sparse_embedding", dtype=DataType.SPARSE_FLOAT_VECTOR) + ], + "index": _construct_index_params, + "corpus": [ + Chunk( + id=1, # type: ignore[arg-type] + content=Content(text="Test document one"), + metadata={"source": "test1"}, + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], + sparse_embedding=([1, 2], [0.1, 0.2])), + ), + Chunk( + id=2, # type: ignore[arg-type] + content=Content(text="Test document two"), + metadata={"source": "test2"}, + embedding=Embedding( + dense_embedding=[0.2, 0.3, 0.4], + sparse_embedding=([2, 3], [0.3, 0.1]), + ), + ), + Chunk( + id=3, # type: ignore[arg-type] + content=Content(text="Test document three"), + metadata={"source": "test3"}, + embedding=Embedding( + dense_embedding=[0.3, 0.4, 0.5], + sparse_embedding=([3, 4], [0.4, 0.2]), + ), + ) + ] +} + + +def create_collection_with_partition( + client: MilvusClient, + collection_name: str, + partition_name: str = '', + fields=None): + + if fields is None: + fields = MILVUS_INGESTION_IT_CONFIG["fields"] + + # Configure schema. + schema = CollectionSchema(fields=fields) + + # Configure index. + index_function: Callable[[], IndexParams] = cast( + Callable[[], IndexParams], MILVUS_INGESTION_IT_CONFIG["index"]) + + # Create collection with schema. + client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_function()) + + # Create partition within the collection. + client.create_partition( + collection_name=collection_name, partition_name=partition_name) + + msg = f"Expected collection '{collection_name}' to be created." + assert client.has_collection(collection_name), msg + + msg = f"Expected partition '{partition_name}' to be created." + assert client.has_partition(collection_name, partition_name), msg + + # Release the collection from memory. We don't need that on pure writing. + client.release_collection(collection_name) + + +def drop_collection(client: MilvusClient, collection_name: str): + try: + client.drop_collection(collection_name) + assert not client.has_collection(collection_name) + except Exception: + # Silently ignore connection errors during cleanup. + pass + + [email protected]_docker_in_docker [email protected]( + platform.system() == "Linux", + "Test runs only on Linux due to lack of support, as yet, for nested " + "virtualization in CI environments on Windows/macOS. 
Many CI providers run " + "tests in virtualized environments, and nested virtualization " + "(Docker inside a VM) is either unavailable or has several issues on " + "non-Linux platforms.") +class TestMilvusVectorWriterConfig(unittest.TestCase): + """Integration tests for Milvus vector database ingestion functionality""" + + _db: VectorDBContainerInfo + + @classmethod + def setUpClass(cls): + cls._db = MilvusTestHelpers.start_db_container() + cls._connection_config = MilvusConnectionParameters( + uri=cls._db.uri, + user=cls._db.user, + password=cls._db.password, + db_name=cls._db.id, + token=cls._db.token) + + @classmethod + def tearDownClass(cls): + MilvusTestHelpers.stop_db_container(cls._db) + cls._db = None + + def setUp(self): + self.write_test_pipeline = TestPipeline() + self.write_test_pipeline.not_use_test_runner_api = True + self._collection_name = f"test_collection_{self._testMethodName}" + self._partition_name = f"test_partition_{self._testMethodName}" + config = unpack_dataclass_with_kwargs(self._connection_config) + config["alias"] = f"milvus_conn_{uuid.uuid4().hex[:8]}" + + # Use retry_with_backoff for test client connection. + def create_client(): + return MilvusClient(**config) + + self._test_client = retry_with_backoff( + create_client, + max_retries=3, + retry_delay=1.0, + operation_name="Test Milvus client connection", + exception_types=(MilvusException, )) + + create_collection_with_partition( + self._test_client, self._collection_name, self._partition_name) + + def tearDown(self): + drop_collection(self._test_client, self._collection_name) + self._test_client.close() + + def test_invalid_write_on_non_existent_collection(self): + non_existent_collection = "nonexistent_collection" + + test_chunks = MILVUS_INGESTION_IT_CONFIG["corpus"] + + write_config = MilvusWriteConfig( + collection_name=non_existent_collection, + write_config=WriteConfig(write_batch_size=1)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, + write_config=write_config, + ) + + # Write pipeline. + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Assert on what should happen. + self.assertIn("can't find collection", str(context.exception).lower()) + + def test_invalid_write_on_non_existent_partition(self): + non_existent_partition = "nonexistent_partition" + + test_chunks = MILVUS_INGESTION_IT_CONFIG["corpus"] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=non_existent_partition, + write_config=WriteConfig(write_batch_size=1)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + # Write pipeline. + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Assert on what should happen. + self.assertIn("partition not found", str(context.exception).lower()) + + def test_invalid_write_on_missing_primary_key_in_entity(self): + test_chunks = [ + Chunk( + content=Content(text="Test content without ID"), + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], + sparse_embedding=([1, 2], [0.1, 0.2])), + metadata={"source": "test"}) + ] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=WriteConfig(write_batch_size=1)) + + # Deliberately remove id primary key from the entity. 
+ specs = MilvusVectorWriterConfig.default_column_specs() + for i, spec in enumerate(specs): + if spec.column_name == "id": + del specs[i] + break + + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, + write_config=write_config, + column_specs=specs) + + # Write pipeline. + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Assert on what should happen. + self.assertIn( + "insert missed an field `id` to collection", + str(context.exception).lower()) + + def test_write_on_auto_id_primary_key(self): + auto_id_collection = f"auto_id_collection_{self._testMethodName}" + auto_id_partition = f"auto_id_partition_{self._testMethodName}" + auto_id_fields = [ + FieldSchema( + name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), + FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=1000), + FieldSchema(name="metadata", dtype=DataType.JSON), + FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=3), + FieldSchema( + name="sparse_embedding", dtype=DataType.SPARSE_FLOAT_VECTOR) + ] + + # Create collection with an auto id field. + create_collection_with_partition( + client=self._test_client, + collection_name=auto_id_collection, + partition_name=auto_id_partition, + fields=auto_id_fields) + + test_chunks = [ + Chunk( + id=1, + content=Content(text="Test content without ID"), + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], + sparse_embedding=([1, 2], [0.1, 0.2])), + metadata={"source": "test"}) + ] + + write_config = MilvusWriteConfig( + collection_name=auto_id_collection, + partition_name=auto_id_partition, + write_config=WriteConfig(write_batch_size=1)) + + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + with self.write_test_pipeline as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + self._test_client.flush(auto_id_collection) + self._test_client.load_collection(auto_id_collection) + result = self._test_client.query( + collection_name=auto_id_collection, + partition_names=[auto_id_partition], + limit=3) + + # Test there is only one item in the result and the ID is not equal to one. + self.assertEqual(len(result), len(test_chunks)) + result_item = dict(result[0]) + self.assertNotEqual(result_item["id"], 1) + + def test_write_on_existent_collection_with_default_schema(self): + test_chunks = MILVUS_INGESTION_IT_CONFIG["corpus"] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=WriteConfig(write_batch_size=3)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + with self.write_test_pipeline as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Verify data was written successfully. + self._test_client.flush(self._collection_name) + self._test_client.load_collection(self._collection_name) + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + limit=10) + + self.assertEqual(len(result), len(test_chunks)) + + # Verify each chunk was written correctly. 
+ result_by_id = {item["id"]: item for item in result} + for chunk in test_chunks: + self.assertIn(chunk.id, result_by_id) + result_item = result_by_id[chunk.id] + self.assertEqual( + result_item["content"], + chunk.content.text + if hasattr(chunk.content, 'text') else chunk.content) + self.assertEqual(result_item["metadata"], chunk.metadata) + + # Verify embedding is present and has correct length. + expected_embedding = chunk.embedding.dense_embedding + actual_embedding = result_item["embedding"] + self.assertIsNotNone(actual_embedding) + self.assertEqual(len(actual_embedding), len(expected_embedding)) + + def test_write_with_custom_column_specifications(self): + from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec + from apache_beam.ml.rag.utils import MilvusHelpers + + custom_column_specs = [ + ColumnSpec("id", int, lambda chunk: int(chunk.id) if chunk.id else 0), + ColumnSpec( + "content", + str, lambda chunk: ( + chunk.content.text + if hasattr(chunk.content, 'text') else chunk.content)), + ColumnSpec("metadata", dict, lambda chunk: chunk.metadata or {}), + ColumnSpec( + "embedding", + list, lambda chunk: chunk.embedding.dense_embedding or []), + ColumnSpec( + "sparse_embedding", + dict, lambda chunk: ( + MilvusHelpers.sparse_embedding( + chunk.embedding.sparse_embedding) if chunk.embedding and + chunk.embedding.sparse_embedding else {})) + ] + + test_chunks = [ + Chunk( + id=10, + content=Content(text="Custom column spec test"), + embedding=Embedding( + dense_embedding=[0.8, 0.9, 1.0], + sparse_embedding=([1, 3, 5], [0.8, 0.9, 1.0])), + metadata={"custom": "spec_test"}) + ] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=WriteConfig(write_batch_size=1)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, + write_config=write_config, + column_specs=custom_column_specs) + + with self.write_test_pipeline as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Verify data was written successfully. + self._test_client.flush(self._collection_name) + self._test_client.load_collection(self._collection_name) + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + filter="id == 10", + limit=1) + + self.assertEqual(len(result), 1) + result_item = result[0] + + # Verify custom column specs worked correctly. + self.assertEqual(result_item["id"], 10) + self.assertEqual(result_item["content"], "Custom column spec test") + self.assertEqual(result_item["metadata"], {"custom": "spec_test"}) + + # Verify embedding is present and has correct length. + expected_embedding = [0.8, 0.9, 1.0] + actual_embedding = result_item["embedding"] + self.assertIsNotNone(actual_embedding) + self.assertEqual(len(actual_embedding), len(expected_embedding)) + + # Verify sparse embedding was converted correctly - check keys are present. 
+ expected_sparse_keys = {1, 3, 5} + actual_sparse = result_item["sparse_embedding"] + self.assertIsNotNone(actual_sparse) + self.assertEqual(set(actual_sparse.keys()), expected_sparse_keys) + + def test_write_with_batching(self): + test_chunks = [ + Chunk( + id=i, + content=Content(text=f"Batch test document {i}"), + embedding=Embedding( + dense_embedding=[0.1 * i, 0.2 * i, 0.3 * i], + sparse_embedding=([i, i + 1], [0.1 * i, 0.2 * i])), + metadata={"batch_id": i}) for i in range(1, 8) # 7 chunks + ] + + # Set small batch size to force batching (7 chunks with batch size 3). + batch_write_config = WriteConfig(write_batch_size=3) + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=batch_write_config) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + with self.write_test_pipeline as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Verify all data was written successfully. + # Flush to persist all data to disk, then load collection for querying. + self._test_client.flush(self._collection_name) + self._test_client.load_collection(self._collection_name) + + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + limit=10) + + self.assertEqual(len(result), len(test_chunks)) + + # Verify each batch was written correctly. + result_by_id = {item["id"]: item for item in result} Review Comment:  Similar to another comment, the `hasattr(chunk.content, 'text')` check in this lambda is confusing. The `Chunk` type definition specifies that `chunk.content` is a `Content` object, so `chunk.content.text` should always be accessible. The `else chunk.content` part seems to handle a case that shouldn't occur and makes the code harder to understand. It's better to rely on the type contract and simplify the lambda. ```python ColumnSpec("content", str, lambda chunk: chunk.content.text), ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
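As a follow-up to the `_MilvusSink` context-manager comment above, here is a minimal sketch of what the simplified `write` method could look like. It assumes `__enter__` initializes `self._client` with the retry-backed connection logic and `__exit__` closes it; the `upsert` parameters shown simply mirror the fields of `MilvusWriteConfig` and are illustrative, not the PR's exact call:

```python
def write(self, documents):
    """Upserts a batch of records, relying on the context manager for setup."""
    if self._client is None:
        # Guard instead of lazily creating a client here, so the retry-backed
        # connection logic in __enter__ is never bypassed.
        raise RuntimeError(
            "_MilvusSink must be used as a context manager before write().")
    self._client.upsert(
        collection_name=self._write_config.collection_name,
        partition_name=self._write_config.partition_name,
        data=documents,
        timeout=self._write_config.timeout,
        **self._write_config.kwargs)
```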
