damccorm commented on code in PR #33364: URL: https://github.com/apache/beam/pull/33364#discussion_r1884153540
########## sdks/python/apache_beam/ml/rag/chunking/langchain.py: ########## @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +import apache_beam as beam +from apache_beam.ml.rag.chunking.base import ChunkIdFn +from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from langchain.text_splitter import TextSplitter + + +class LangChainChunkingProvider(ChunkingTransformProvider): + def __init__( + self, + text_splitter: TextSplitter, + document_field: str, + metadata_fields: List[str], + chunk_id_fn: Optional[ChunkIdFn] = None): + if not isinstance(text_splitter, TextSplitter): + raise TypeError("text_splitter must be a LangChain TextSplitter") + if not document_field: + raise ValueError("document_field cannot be empty") + super().__init__(chunk_id_fn) + self.text_splitter = text_splitter + self.document_field = document_field + self.metadata_fields = metadata_fields + + def get_text_splitter_transform( + self + ) -> beam.PTransform[beam.PCollection[Dict[str, Any]], + beam.PCollection[Chunk]]: + return 
"Langchain text split" >> beam.ParDo( + LangChainTextSplitter( + text_splitter=self.text_splitter, + document_field=self.document_field, + metadata_fields=self.metadata_fields)) + + +class LangChainTextSplitter(beam.DoFn): Review Comment: Should this be publicly exposed? ```suggestion class _LangChainTextSplitter(beam.DoFn): ``` ########## sdks/python/apache_beam/ml/rag/chunking/langchain.py: ########## @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +import apache_beam as beam +from apache_beam.ml.rag.chunking.base import ChunkIdFn +from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from langchain.text_splitter import TextSplitter + + +class LangChainChunkingProvider(ChunkingTransformProvider): Review Comment: I think this should probably be named something like `LangChainChunker` - I don't think we use the provider terminology elsewhere in a user-facing way. 
########## sdks/python/apache_beam/ml/rag/types.py: ########## @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Core types for RAG pipelines. +This module contains the core dataclasses used throughout the RAG pipeline +implementation, including Chunk and Embedding types that define the data +contracts between different stages of the pipeline. +""" + +import uuid +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + + +@dataclass +class Content: + """Container for embeddable content. + """ + text: Optional[str] = None + image_data: Optional[bytes] = None + + +@dataclass +class Chunk: + """Represents a chunk of text with metadata. 
+ + Attributes: + text: The actual content of the chunk + id: Unique identifier for the chunk + index: Index of this chunk within the original document + metadata: Additional metadata about the chunk (e.g., document source) + """ + content: Content + id: str = field(default_factory=lambda: str(uuid.uuid4())) + index: int = 0 + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class Embedding: Review Comment: Per https://github.com/apache/beam/pull/33313#discussion_r1882884174 I think we are going to lightly modify this (just dropping this so we don't lose track of it) ########## sdks/python/apache_beam/ml/rag/embeddings/huggingface.py: ########## @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""RAG-specific embedding implementations using HuggingFace models.""" + +from typing import Optional + +import apache_beam as beam +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.rag.embeddings.base import create_rag_adapter +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Embedding +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformer +from apache_beam.ml.transforms.embeddings.huggingface import _SentenceTransformerModelHandler + + +class HuggingfaceTextEmbeddings(EmbeddingsManager): + """SentenceTransformer embeddings for RAG pipeline. + + Extends EmbeddingsManager to work with RAG-specific types: + - Input: Chunk objects containing text to embed + - Output: Embedding objects containing vector representations + + The adapter automatically: + - Extracts text from Chunk.content.text + - Preserves Chunk.id in Embedding.id + - Copies Chunk.metadata to Embedding.metadata + - Converts model output to Embedding.dense_embedding Review Comment: Could we also follow this convention for parameters for pydocs - https://github.com/apache/beam/blob/a6061feb05363a6175e28f80f684b6d306a42150/sdks/python/apache_beam/ml/transforms/base.py#L247 That will allow them to be picked up as known params in the rendered doc ########## sdks/python/apache_beam/ml/transforms/base.py: ########## @@ -62,36 +68,31 @@ # Output of the apply() method of BaseOperation. 
OperationOutputT = TypeVar('OperationOutputT') +# Input to the EmbeddingTypeAdapter input_fn +EmbeddingTypeAdapterInputT = TypeVar( + 'EmbeddingTypeAdapterInputT') # e.g., Chunk +# Output of the EmbeddingTypeAdapter output_fn +EmbeddingTypeAdapterOutputT = TypeVar( + 'EmbeddingTypeAdapterOutputT') # e.g., Embedding -def _convert_list_of_dicts_to_dict_of_lists( - list_of_dicts: Sequence[dict[str, Any]]) -> dict[str, list[Any]]: - keys_to_element_list = collections.defaultdict(list) - input_keys = list_of_dicts[0].keys() - for d in list_of_dicts: - if set(d.keys()) != set(input_keys): - extra_keys = set(d.keys()) - set(input_keys) if len( - d.keys()) > len(input_keys) else set(input_keys) - set(d.keys()) - raise RuntimeError( - f'All the dicts in the input data should have the same keys. ' - f'Got: {extra_keys} instead.') - for key, value in d.items(): - keys_to_element_list[key].append(value) - return keys_to_element_list - - -def _convert_dict_of_lists_to_lists_of_dict( - dict_of_lists: dict[str, list[Any]]) -> list[dict[str, Any]]: - batch_length = len(next(iter(dict_of_lists.values()))) - result: list[dict[str, Any]] = [{} for _ in range(batch_length)] - # all the values in the dict_of_lists should have same length - for key, values in dict_of_lists.items(): - assert len(values) == batch_length, ( - "This function expects all the values " - "in the dict_of_lists to have same length." - ) - for i in range(len(values)): - result[i][key] = values[i] - return result + +@dataclass +class EmbeddingTypeAdapter(Generic[EmbeddingTypeAdapterInputT, + EmbeddingTypeAdapterOutputT]): + """Adapts input types to text for embedding and converts output embeddings. 
+ + Args: + input_fn: Function to extract text for embedding from input type + output_fn: Function to create output type from input and embeddings + """ + input_fn: Callable[[Sequence[EmbeddingTypeAdapterInputT]], List[str]] + output_fn: Callable[[Sequence[EmbeddingTypeAdapterInputT], Sequence[Any]], + List[EmbeddingTypeAdapterOutputT]] + + def __reduce__(self): + """Custom serialization that preserves type information during + jsonpickle.""" Review Comment: I don't quite follow this - why do we need this piece? ########## sdks/python/apache_beam/ml/rag/chunking/langchain.py: ########## @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +import apache_beam as beam +from apache_beam.ml.rag.chunking.base import ChunkIdFn +from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from langchain.text_splitter import TextSplitter + + +class LangChainChunkingProvider(ChunkingTransformProvider): Review Comment: Also, could you add a comment for the pydoc here? 
########## sdks/python/apache_beam/ml/transforms/base.py: ########## @@ -182,13 +183,74 @@ def append_transform(self, transform: BaseOperation): """ +def dict_input_fn(columns: Sequence[str], Review Comment: Nit - I think these functions can be private (`_dict_input_fn`) ########## sdks/python/apache_beam/ml/rag/chunking/base.py: ########## @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +import functools +from collections.abc import Callable +from typing import Any +from typing import Dict +from typing import Optional + +import apache_beam as beam +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.transforms.base import MLTransformProvider + +ChunkIdFn = Callable[[Chunk], str] + + +def assign_chunk_id(chunk_id_fn: ChunkIdFn, chunk: Chunk): + chunk.id = chunk_id_fn(chunk) + return chunk + + +class ChunkingTransformProvider(MLTransformProvider): Review Comment: Could we add a docstring here as well? 
In general, probably all non-test classes should either have a docstring or should be prefixed with an underscore. ########## sdks/python/apache_beam/ml/rag/types.py: ########## @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Core types for RAG pipelines. +This module contains the core dataclasses used throughout the RAG pipeline +implementation, including Chunk and Embedding types that define the data +contracts between different stages of the pipeline. +""" + +import uuid +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + + +@dataclass +class Content: + """Container for embeddable content. + """ + text: Optional[str] = None + image_data: Optional[bytes] = None Review Comment: I'm leaning towards not including `image_data` in v1, since we don't have anything yet that knows how to handle it. I still think having the dataclass makes sense, and we can append fields as we go (we can call this out in the docstring). -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
