damccorm commented on code in PR #29564: URL: https://github.com/apache/beam/pull/29564#discussion_r1414418545
########## sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py: ########## @@ -0,0 +1,124 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterable +from typing import List +from typing import Optional + +import apache_beam as beam +import tensorflow as tf +import tensorflow_hub as hub +import tensorflow_text as text # required to register TF ops. # pylint: disable=unused-import +from apache_beam.ml.inference import utils +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor +from apache_beam.ml.inference.tensorflow_inference import default_tensor_inference_fn +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler + +__all__ = ['TensorflowHubTextEmbeddings'] + + +class _TensorflowHubModelHandler(TFModelHandlerTensor): + """ + Note: Intended for internal use only. No backwards compatibility guarantees. 
+ """ + def __init__(self, preprocessing_url: Optional[str], *args, **kwargs): + self.preprocessing_url = preprocessing_url + super().__init__(*args, **kwargs) + + def load_model(self): + # unable to load the models with tf.keras.models.load_model so Review Comment: Why is this? ########## sdks/python/apache_beam/ml/transforms/embeddings/sentence_transformer_test.py: ########## @@ -0,0 +1,212 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import tempfile +import unittest + +import apache_beam as beam +from apache_beam.ml.transforms.base import MLTransform + +# pylint: disable=ungrouped-imports +try: + from apache_beam.ml.transforms.embeddings.sentence_transformer import SentenceTransformerEmbeddings +except ImportError: + SentenceTransformerEmbeddings = None # type: ignore + +# pylint: disable=ungrouped-imports +try: + import tensorflow_transform as tft + from apache_beam.ml.transforms.tft import ScaleTo01 +except ImportError: + tft = None + +test_query = "This is a test" +test_query_column = "feature_1" +DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" + + +def get_pipeline_wth_embedding_config( Review Comment: Nit: this helper is only called once and can probably be inlined. 
Though adding some more configs wouldn't be terrible if we're using our own model handler. I'd be less worried if we could just reuse the existing Hugging Face Model Handlers since our testable surface would be smaller. ########## sdks/python/apache_beam/ml/transforms/utils.py: ########## @@ -28,8 +30,13 @@ class ArtifactsFetcher(): to the TFTProcessHandlers in MLTransform. """ def __init__(self, artifact_location): - self.artifact_location = artifact_location - self.transform_output = tft.TFTransformOutput(self.artifact_location) + files = os.listdir(artifact_location) + files.remove(base._ATTRIBUTE_FILE_NAME) + if len(files) > 1: + raise NotImplementedError( + 'Multiple files in artifact location not supported yet.') Review Comment: Could we be a little more instructive - something like: "Multiple files in artifact location not supported yet. Found %s files. Please specify a location with a single file." ########## sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py: ########## @@ -0,0 +1,124 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Iterable +from typing import List +from typing import Optional + +import apache_beam as beam +import tensorflow as tf +import tensorflow_hub as hub +import tensorflow_text as text # required to register TF ops. # pylint: disable=unused-import +from apache_beam.ml.inference import utils +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor +from apache_beam.ml.inference.tensorflow_inference import default_tensor_inference_fn +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler + +__all__ = ['TensorflowHubTextEmbeddings'] + + +class _TensorflowHubModelHandler(TFModelHandlerTensor): + """ + Note: Intended for internal use only. No backwards compatibility guarantees. + """ + def __init__(self, preprocessing_url: Optional[str], *args, **kwargs): + self.preprocessing_url = preprocessing_url + super().__init__(*args, **kwargs) + + def load_model(self): + # unable to load the models with tf.keras.models.load_model so Review Comment: Note, this also keeps us from being able to just use TFModelHandlerTensor with a custom inference function (unclear if that's cleaner or not) ########## sdks/python/apache_beam/ml/transforms/embeddings/sentence_transformer.py: ########## @@ -0,0 +1,125 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ["SentenceTransformerEmbeddings"] + +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence + +import apache_beam as beam +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from sentence_transformers import SentenceTransformer + + +# TODO: Use HuggingFaceModelHandlerTensor once the import issue is fixed. Review Comment: Which import issue is this referencing? Is it that we have top level imports of both tensorflow and pytorch (and types that depend on them)? Could you link to an issue (we should create one if we don't already have one)? ########## sdks/python/apache_beam/ml/transforms/embeddings/sentence_transformer.py: ########## @@ -0,0 +1,125 @@ +# Review Comment: Could we name this file hugging_face.py so that the import is `apache_beam.<...>.hugging_face.SentenceTransformerEmbeddings`? ########## sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py: ########## @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Vertex AI Python SDK is required for this module. +# Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long +# to install Vertex AI Python SDK. + +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence + +from google.auth.credentials import Credentials + +import apache_beam as beam +import vertexai +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from vertexai.language_models import TextEmbeddingInput +from vertexai.language_models import TextEmbeddingModel + +__all__ = ["VertexAITextEmbeddings"] + +TASK_TYPE = "RETRIEVAL_DOCUMENT" Review Comment: Maybe `DEFAULT_TASK_TYPE`? ########## sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py: ########## @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Vertex AI Python SDK is required for this module. +# Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long +# to install Vertex AI Python SDK. + +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence + +from google.auth.credentials import Credentials + +import apache_beam as beam +import vertexai +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from vertexai.language_models import TextEmbeddingInput +from vertexai.language_models import TextEmbeddingModel + +__all__ = ["VertexAITextEmbeddings"] + +TASK_TYPE = "RETRIEVAL_DOCUMENT" +TASK_TYPE_INPUTS = [ + "RETRIEVAL_DOCUMENT", + "RETRIEVAL_QUERY", + "SEMANTIC_SIMILARITY", + "CLASSIFICATION", + "CLUSTERING" +] + + +class _VertexAITextEmbeddingHandler(ModelHandler): + """ + Note: Intended for internal use and guarantees no backwards compatibility. 
+ """ + def __init__( + self, + model_name: str, + title: Optional[str] = None, + task_type: str = TASK_TYPE, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[Credentials] = None, + ): + vertexai.init(project=project, location=location, credentials=credentials) + self.model_name = model_name + if task_type not in TASK_TYPE_INPUTS: + raise ValueError( + f"task_type must be one of {TASK_TYPE_INPUTS}, got {task_type}") Review Comment: What happens if we don't throw here. Does this eventually error? I wonder if there's a way for us to check if this is a valid task type using vertex's service instead of coding it ourselves. That way if the supported task list changes we don't need changes on our end. ########## sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py: ########## @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Vertex AI Python SDK is required for this module. +# Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long +# to install Vertex AI Python SDK. 
+ +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence + +from google.auth.credentials import Credentials + +import apache_beam as beam +import vertexai +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from vertexai.language_models import TextEmbeddingInput +from vertexai.language_models import TextEmbeddingModel + +__all__ = ["VertexAITextEmbeddings"] + +TASK_TYPE = "RETRIEVAL_DOCUMENT" +TASK_TYPE_INPUTS = [ + "RETRIEVAL_DOCUMENT", + "RETRIEVAL_QUERY", + "SEMANTIC_SIMILARITY", + "CLASSIFICATION", + "CLUSTERING" +] + + +class _VertexAITextEmbeddingHandler(ModelHandler): + """ + Note: Intended for internal use and guarantees no backwards compatibility. + """ + def __init__( + self, + model_name: str, + title: Optional[str] = None, + task_type: str = TASK_TYPE, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[Credentials] = None, + ): + vertexai.init(project=project, location=location, credentials=credentials) + self.model_name = model_name + if task_type not in TASK_TYPE_INPUTS: + raise ValueError( + f"task_type must be one of {TASK_TYPE_INPUTS}, got {task_type}") + self.task_type = task_type + self.title = title + + def run_inference( + self, + batch: Sequence[str], + model: Any, + inference_args: Optional[Dict[str, Any]] = None, + ) -> Iterable: + embeddings = [] + batch_size = 5 # Vertex AI limits requests to 5 at a time. Review Comment: Could we make this a constant defined next to TASK_TYPE? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
