riteshghorse commented on code in PR #30289: URL: https://github.com/apache/beam/pull/30289#discussion_r1491837512
# TODO: https://github.com/apache/beam/issues/30288
# Replace with TFModelHandlerTensor when load_model() supports TFHUB models.
class _TensorflowHubModelHandler(TFModelHandlerTensor):
  """ModelHandler that loads and runs TensorFlow Hub text-embedding models.

  Note: Intended for internal use only. No backwards compatibility guarantees.
  """
  def __init__(self, preprocessing_url: Optional[str], *args, **kwargs):
    """
    Args:
      preprocessing_url: Optional URL of a TF Hub preprocessing model
        (e.g. a BERT tokenizer) applied to the raw text before it is fed
        to the main model. If None, inputs are passed to the model as-is.
    """
    self.preprocessing_url = preprocessing_url
    # Lazily-built preprocessing layer, cached so the hub module is read
    # at most once per handler instead of once per run_inference() call.
    self._preprocessing_layer = None
    super().__init__(*args, **kwargs)

  def load_model(self):
    """Loads the model from the TF Hub URL.

    tf.keras.models.load_model cannot load TF Hub models, so
    hub.KerasLayer is used instead.
    """
    return hub.KerasLayer(self._model_uri, **self._load_model_args)

  def _convert_prediction_result_to_list(
      self, predictions: Iterable[PredictionResult]) -> List[List[float]]:
    # Each inference is a tf.Tensor; convert to plain Python lists so the
    # embeddings can be consumed downstream without TensorFlow.
    return [
        prediction.inference.numpy().tolist() for prediction in predictions
    ]

  def run_inference(self, batch, model, inference_args, model_id=None):
    """Runs inference on a batch of text and returns embeddings as lists.

    Args:
      batch: A batch of text inputs (tensors) to embed.
      model: The loaded TF Hub model (a hub.KerasLayer).
      inference_args: Optional extra keyword args for the inference fn.
      model_id: Optional identifier recorded in the PredictionResult.

    Returns:
      A list of embeddings, one per input, each a plain Python list.
    """
    inference_args = inference_args or {}
    if not self.preprocessing_url:
      # No preprocessing model: defer to the default TF tensor inference.
      predictions = default_tensor_inference_fn(
          model=model,
          batch=batch,
          inference_args=inference_args,
          model_id=model_id)
      return self._convert_prediction_result_to_list(predictions)

    vectorized_batch = tf.stack(batch, axis=0)
    # Build the preprocessing layer once and reuse it across calls;
    # constructing a hub.KerasLayer re-reads the hub module each time,
    # which is expensive to do per bundle.
    if self._preprocessing_layer is None:
      self._preprocessing_layer = hub.KerasLayer(self.preprocessing_url)
    vectorized_batch = self._preprocessing_layer(vectorized_batch)
    predictions = model(vectorized_batch)
    # https://www.tensorflow.org/text/tutorials/classify_text_with_bert#using_the_bert_model # pylint: disable=line-too-long
    # pooled_output -> represents the text as a whole; an embedding of
    # the whole text with shape [batch_size, embedding_dimension].
    # sequence_output -> represents the text as a sequence of tokens; an
    # embedding of each token with shape
    # [batch_size, max_sequence_length, embedding_dimension].
    # The pooled output is the sentence embedding per the documentation,
    # so use that.
    embeddings = predictions['pooled_output']
    predictions = utils._convert_to_result(batch, embeddings, model_id)
    return self._convert_prediction_result_to_list(predictions)
