AnandInguva commented on code in PR #29564:
URL: https://github.com/apache/beam/pull/29564#discussion_r1416277724
##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -254,3 +371,243 @@ def _increment_counters():
pipeline
| beam.Create([None])
| beam.Map(lambda _: _increment_counters()))
+
+
+class _TransformAttributeManager:
+ """
+ Base class used for saving and loading the attributes.
+ """
+ @staticmethod
+ def save_attributes(artifact_location):
+ """
+ Save the attributes to json file using stdlib json.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def load_attributes(artifact_location):
+ """
+ Load the attributes from json file.
+ """
+ raise NotImplementedError
+
+
+class _JsonPickleTransformAttributeManager(_TransformAttributeManager):
+ """
+ Use Jsonpickle to save and load the attributes. Here the attributes refer
+ to the list of PTransforms that are used to process the data.
+
+ jsonpickle is used to serialize the PTransforms and save it to a json file
and
+ is compatible across python versions.
+ """
+ @staticmethod
+ def _is_remote_path(path):
+ is_gcs = path.find('gs://') != -1
+ # TODO: Add support for other remote paths.
+ if not is_gcs and path.find('://') != -1:
+ raise RuntimeError(
+ "Artifact locations are currently supported for only available for "
+ "local paths and GCS paths. Got: %s" % path)
+ return is_gcs
+
+ @staticmethod
+ def save_attributes(
+ ptransform_list,
+ artifact_location,
+ **kwargs,
+ ):
+ if _JsonPickleTransformAttributeManager._is_remote_path(artifact_location):
+ try:
+ options = kwargs.get('options')
+ except KeyError:
+ raise RuntimeError(
+ 'pipeline options are required to save the attributes.'
+ 'in the artifact location %s' % artifact_location)
Review Comment:
Done. Removed the try/except.
##########
sdks/python/apache_beam/ml/transforms/embeddings/sentence_transformer.py:
##########
@@ -0,0 +1,125 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["SentenceTransformerEmbeddings"]
+
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
+
+import apache_beam as beam
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.transforms.base import EmbeddingsManager
+from apache_beam.ml.transforms.base import _TextEmbeddingHandler
+from sentence_transformers import SentenceTransformer
+
+
+# TODO: Use HuggingFaceModelHandlerTensor once the import issue is fixed.
Review Comment:
Added a link to the TODO: https://github.com/apache/beam/issues/29621
##########
sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py:
##########
@@ -0,0 +1,158 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Vertex AI Python SDK is required for this module.
+# Follow
https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk #
pylint: disable=line-too-long
+# to install Vertex AI Python SDK.
+
+from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Sequence
+
+from google.auth.credentials import Credentials
+
+import apache_beam as beam
+import vertexai
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.transforms.base import EmbeddingsManager
+from apache_beam.ml.transforms.base import _TextEmbeddingHandler
+from vertexai.language_models import TextEmbeddingInput
+from vertexai.language_models import TextEmbeddingModel
+
+__all__ = ["VertexAITextEmbeddings"]
+
+TASK_TYPE = "RETRIEVAL_DOCUMENT"
+TASK_TYPE_INPUTS = [
+ "RETRIEVAL_DOCUMENT",
+ "RETRIEVAL_QUERY",
+ "SEMANTIC_SIMILARITY",
+ "CLASSIFICATION",
+ "CLUSTERING"
+]
+
+
+class _VertexAITextEmbeddingHandler(ModelHandler):
+ """
+ Note: Intended for internal use and guarantees no backwards compatibility.
+ """
+ def __init__(
+ self,
+ model_name: str,
+ title: Optional[str] = None,
+ task_type: str = TASK_TYPE,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[Credentials] = None,
+ ):
+ vertexai.init(project=project, location=location, credentials=credentials)
+ self.model_name = model_name
+ if task_type not in TASK_TYPE_INPUTS:
+ raise ValueError(
+ f"task_type must be one of {TASK_TYPE_INPUTS}, got {task_type}")
Review Comment:
400 EMPTY_TASK cannot be parsed as a valid embedding task type. Valid task
types: [CLASSIFICATION, CLUSTERING, DEFAULT, RETRIEVAL_DOCUMENT,
RETRIEVAL_QUERY, SEMANTIC_SIMILARITY] [while running
'MLTransform/RunInference/BeamML_RunInference']
It does error out, but only after job submission. I added this check so that
we fail faster, before the job is submitted.
I will see if we could automatically pull the list of valid task types.
##########
sdks/python/apache_beam/ml/transforms/embeddings/sentence_transformer_test.py:
##########
@@ -0,0 +1,212 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+import tempfile
+import unittest
+
+import apache_beam as beam
+from apache_beam.ml.transforms.base import MLTransform
+
+# pylint: disable=ungrouped-imports
+try:
+ from apache_beam.ml.transforms.embeddings.sentence_transformer import
SentenceTransformerEmbeddings
+except ImportError:
+ SentenceTransformerEmbeddings = None # type: ignore
+
+# pylint: disable=ungrouped-imports
+try:
+ import tensorflow_transform as tft
+ from apache_beam.ml.transforms.tft import ScaleTo01
+except ImportError:
+ tft = None
+
+test_query = "This is a test"
+test_query_column = "feature_1"
+DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
+
+
+def get_pipeline_wth_embedding_config(
Review Comment:
Removed this method. +1 to using our own ModelHandler if we could, but having
to download both PyTorch and TensorFlow at the same time is a big -1. I will
link an issue to the TODO.
##########
sdks/python/apache_beam/ml/transforms/utils.py:
##########
@@ -28,8 +30,13 @@ class ArtifactsFetcher():
to the TFTProcessHandlers in MLTransform.
"""
def __init__(self, artifact_location):
- self.artifact_location = artifact_location
- self.transform_output = tft.TFTransformOutput(self.artifact_location)
+ files = os.listdir(artifact_location)
+ files.remove(base._ATTRIBUTE_FILE_NAME)
+ if len(files) > 1:
+ raise NotImplementedError(
+ 'Multiple files in artifact location not supported yet.')
Review Comment:
The user wouldn't know, since we create a random directory inside the artifact
location. I added more info to the NotImplementedError message. One idea is to
integrate ArtifactsFetcher into MLTransform, but that can be addressed later.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]