mohamedawnallah commented on code in PR #36729: URL: https://github.com/apache/beam/pull/36729#discussion_r2549381202
########## sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py: ########## @@ -0,0 +1,359 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional + +from pymilvus import MilvusClient +from pymilvus.exceptions import MilvusException + +import apache_beam as beam +from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig +from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.utils import DEFAULT_WRITE_BATCH_SIZE +from apache_beam.ml.rag.utils import MilvusConnectionParameters +from apache_beam.ml.rag.utils import MilvusHelpers +from apache_beam.ml.rag.utils import retry_with_backoff +from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs +from apache_beam.transforms import DoFn + +_LOGGER = logging.getLogger(__name__) + + +@dataclass +class MilvusWriteConfig: + """Configuration 
parameters for writing data to Milvus collections. + + This class defines the parameters needed to write data to a Milvus collection, + including collection targeting, batching behavior, and operation timeouts. + + Args: + collection_name: Name of the target Milvus collection to write data to. + Must be a non-empty string. + partition_name: Name of the specific partition within the collection to + write to. If empty, writes to the default partition. + timeout: Maximum time in seconds to wait for write operations to complete. + If None, uses the client's default timeout. + write_config: Configuration for write operations including batch size and + other write-specific settings. + kwargs: Additional keyword arguments for write operations. Enables forward + compatibility with future Milvus client parameters. + """ + collection_name: str + partition_name: str = "" + timeout: Optional[float] = None + write_config: WriteConfig = field(default_factory=WriteConfig) + kwargs: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if not self.collection_name: + raise ValueError("Collection name must be provided") + + @property + def write_batch_size(self): + """Returns the batch size for write operations. + + Returns: + The configured batch size, or DEFAULT_WRITE_BATCH_SIZE if not specified. + """ + return self.write_config.write_batch_size or DEFAULT_WRITE_BATCH_SIZE + + +@dataclass +class MilvusVectorWriterConfig(VectorDatabaseWriteConfig): + """Configuration for writing vector data to Milvus collections. + + This class extends VectorDatabaseWriteConfig to provide Milvus-specific + configuration for ingesting vector embeddings and associated metadata. + It defines how Apache Beam chunks are converted to Milvus records and + handles the write operation parameters. + + The configuration includes connection parameters, write settings, and + column specifications that determine how chunk data is mapped to Milvus + fields. 
+ + Args: + connection_params: Configuration for connecting to the Milvus server, + including URI, credentials, and connection options. + write_config: Configuration for write operations including collection name, + partition, batch size, and timeouts. + column_specs: List of column specifications defining how chunk fields are + mapped to Milvus collection fields. Defaults to standard RAG fields + (id, embedding, sparse_embedding, content, metadata). + + Example: + config = MilvusVectorWriterConfig( + connection_params=MilvusConnectionParameters( + uri="http://localhost:19530"), + write_config=MilvusWriteConfig(collection_name="my_collection"), + column_specs=MilvusVectorWriterConfig.default_column_specs()) + """ + connection_params: MilvusConnectionParameters + write_config: MilvusWriteConfig + column_specs: List[ColumnSpec] = field( + default_factory=lambda: MilvusVectorWriterConfig.default_column_specs()) + + def create_converter(self) -> Callable[[Chunk], Dict[str, Any]]: + """Creates a function to convert Apache Beam Chunks to Milvus records. + + Returns: + A function that takes a Chunk and returns a dictionary representing + a Milvus record with fields mapped according to column_specs. + """ + def convert(chunk: Chunk) -> Dict[str, Any]: + result = {} + for col in self.column_specs: + result[col.column_name] = col.value_fn(chunk) + return result + + return convert + + def create_write_transform(self) -> beam.PTransform: + """Creates the Apache Beam transform for writing to Milvus. + + Returns: + A PTransform that can be applied to a PCollection of Chunks to write + them to the configured Milvus collection. + """ + return _WriteToMilvusVectorDatabase(self) + + @staticmethod + def default_column_specs() -> List[ColumnSpec]: + """Returns default column specifications for RAG use cases. + + Creates column mappings for standard RAG fields: id, dense embedding, + sparse embedding, content text, and metadata. 
These specifications + define how Chunk fields are converted to Milvus-compatible formats. + + Returns: + List of ColumnSpec objects defining the default field mappings. + """ + column_specs = ColumnSpecsBuilder() + return column_specs\ + .with_id_spec()\ + .with_embedding_spec(convert_fn=lambda values: list(values))\ + .with_sparse_embedding_spec(conv_fn=MilvusHelpers.sparse_embedding)\ + .with_content_spec()\ + .with_metadata_spec(convert_fn=lambda values: dict(values))\ + .build() + + +class _WriteToMilvusVectorDatabase(beam.PTransform): + """Apache Beam PTransform for writing vector data to Milvus. + + This transform handles the conversion of Apache Beam Chunks to Milvus records + and coordinates the write operations. It applies the configured converter + function and uses a DoFn for batched writes to optimize performance. + + Args: + config: MilvusVectorWriterConfig containing all necessary parameters for + the write operation. + """ + def __init__(self, config: MilvusVectorWriterConfig): + self.config = config + + def expand(self, pcoll: beam.PCollection[Chunk]): + """Expands the PTransform to convert chunks and write to Milvus. + + Args: + pcoll: PCollection of Chunk objects to write to Milvus. + + Returns: + PCollection of the same Chunk objects after writing to Milvus. + """ + return ( + pcoll + | "Convert to Records" >> beam.Map(self.config.create_converter()) + | beam.ParDo( + _WriteMilvusFn( + self.config.connection_params, self.config.write_config))) + + +class _WriteMilvusFn(DoFn): + """DoFn that handles batched writes to Milvus. + + This DoFn accumulates records in batches and flushes them to Milvus when + the batch size is reached or when the bundle finishes. This approach + optimizes performance by reducing the number of individual write operations. + + Args: + connection_params: Configuration for connecting to the Milvus server. + write_config: Configuration for write operations including batch size + and collection details. 
+ """ + def __init__( + self, + connection_params: MilvusConnectionParameters, + write_config: MilvusWriteConfig): + self._connection_params = connection_params + self._write_config = write_config + self.batch = [] + + def process(self, element, *args, **kwargs): + """Processes individual records, batching them for efficient writes. + + Args: + element: A dictionary representing a Milvus record to write. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Yields: + The original element after adding it to the batch. + """ + _ = args, kwargs # Unused parameters + self.batch.append(element) + if len(self.batch) >= self._write_config.write_batch_size: + self._flush() + yield element + + def finish_bundle(self): + """Called when a bundle finishes processing. + + Flushes any remaining records in the batch to ensure all data is written. + """ + self._flush() + + def _flush(self): + """Flushes the current batch of records to Milvus. + + Creates a MilvusSink connection and writes all batched records, + then clears the batch for the next set of records. + """ + if len(self.batch) == 0: + return + with _MilvusSink(self._connection_params, self._write_config) as sink: + sink.write(self.batch) + self.batch = [] + + def display_data(self): + """Returns display data for monitoring and debugging. + + Returns: + Dictionary containing database, collection, and batch size information + for display in the Apache Beam monitoring UI. + """ + res = super().display_data() + res["database"] = self._connection_params.db_name + res["collection"] = self._write_config.collection_name + res["batch_size"] = self._write_config.write_batch_size + return res + + +class _MilvusSink: + """Low-level sink for writing data directly to Milvus. + + This class handles the direct interaction with the Milvus client for + upsert operations. It manages the connection lifecycle and provides + context manager support for proper resource cleanup. 
+ + Args: + connection_params: Configuration for connecting to the Milvus server. + write_config: Configuration for write operations including collection + and partition targeting. + """ + def __init__( + self, + connection_params: MilvusConnectionParameters, + write_config: MilvusWriteConfig): + self._connection_params = connection_params + self._write_config = write_config + self._client = None + + def write(self, documents): + """Writes a batch of documents to the Milvus collection. + + Performs an upsert operation to insert new documents or update existing + ones based on primary key. After the upsert, flushes the collection to + ensure data persistence. + + Args: + documents: List of dictionaries representing Milvus records to write. + Each dictionary should contain fields matching the collection schema. + """ + if not self._client: + self._client = MilvusClient( + **unpack_dataclass_with_kwargs(self._connection_params)) Review Comment: > The `_MilvusSink` is used as a context manager, which guarantees that `__enter__` is called to initialize `self._client` with retry logic. This explicit check and client initialization in the `write` method is redundant and bypasses the robust connection logic in `__enter__`. It's safer to rely on the context manager to handle client setup. Addressed -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
