riteshghorse commented on code in PR #30289:
URL: https://github.com/apache/beam/pull/30289#discussion_r1491502583


##########
sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py:
##########
@@ -0,0 +1,173 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import tempfile
+import unittest
+import uuid
+
+import apache_beam as beam
+from apache_beam.ml.transforms.base import MLTransform
+
+hub_url = 'https://tfhub.dev/google/nnlm-en-dim128/2'
+test_query_column = 'test_query'
+test_query = 'This is a test query'
+
+# pylint: disable=ungrouped-imports
+try:
+  import tensorflow as tf  # disable=unused-import
+  from apache_beam.ml.transforms.embeddings.tensorflow_hub import 
TensorflowHubTextEmbeddings
+except ImportError:
+  tf = None
+
+try:
+  from apache_beam.ml.transforms.tft import ScaleTo01
+except ImportError:
+  ScaleTo01 = None  # type: ignore

Review Comment:
   (Optional) Nit: consider narrowing the scope of the suppression to `# type: ignore[assignment]`.



##########
sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py:
##########
@@ -0,0 +1,129 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+import apache_beam as beam
+import tensorflow as tf
+import tensorflow_hub as hub
+import tensorflow_text as text  # required to register TF ops. # pylint: 
disable=unused-import
+from apache_beam.ml.inference import utils
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
+from apache_beam.ml.inference.tensorflow_inference import 
default_tensor_inference_fn
+from apache_beam.ml.transforms.base import EmbeddingsManager
+from apache_beam.ml.transforms.base import _TextEmbeddingHandler
+
+__all__ = ['TensorflowHubTextEmbeddings']
+
+
+# TODO: https://github.com/apache/beam/issues/30288
+# Replace with TFModelHandlerTensor when load_model() supports TFHUB models.
+class _TensorflowHubModelHandler(TFModelHandlerTensor):
+  """
+  Note: Intended for internal use only. No backwards compatibility guarantees.
+  """
+  def __init__(self, preprocessing_url: Optional[str], *args, **kwargs):
+    self.preprocessing_url = preprocessing_url
+    super().__init__(*args, **kwargs)
+
+  def load_model(self):
+    # unable to load the models with tf.keras.models.load_model so
+    # using hub.KerasLayer instead
+    model = hub.KerasLayer(self._model_uri, **self._load_model_args)
+    return model
+
+  def _convert_prediction_result_to_list(
+      self, predictions: Iterable[PredictionResult]):
+    result = []
+    for prediction in predictions:
+      inference = prediction.inference.numpy().tolist()
+      result.append(inference)
+    return result
+
+  def run_inference(self, batch, model, inference_args, model_id=None):
+    if not inference_args:
+      inference_args = {}
+    if not self.preprocessing_url:
+      predictions = default_tensor_inference_fn(
+          model=model,
+          batch=batch,
+          inference_args=inference_args,
+          model_id=model_id)
+      return self._convert_prediction_result_to_list(predictions)
+
+    vectorized_batch = tf.stack(batch, axis=0)
+    preprocessor_fn = hub.KerasLayer(self.preprocessing_url)
+    vectorized_batch = preprocessor_fn(vectorized_batch)
+    predictions = model(vectorized_batch)
+    # 
https://www.tensorflow.org/text/tutorials/classify_text_with_bert#using_the_bert_model
 # pylint: disable=line-too-long
+    # pooled_output -> represents the text as a whole. This is an embeddings
+    # of the whole text. The shape is [batch_size, embedding_dimension]
+    # sequence_output -> represents the text as a sequence of tokens. This is
+    # an embeddings of each token in the text. The shape is
+    # [batch_size, max_sequence_length, embedding_dimension]
+    # pooled output is the embeedings as per the documentation. so let's use
+    # that.
+    embeddings = predictions['pooled_output']
+    predictions = utils._convert_to_result(batch, embeddings, model_id)
+    return self._convert_prediction_result_to_list(predictions)
+
+
+class TensorflowHubTextEmbeddings(EmbeddingsManager):
+  def __init__(
+      self,
+      columns: List[str],
+      hub_url: str,
+      preprocessing_url: Optional[str] = None,
+      **kwargs):
+    """
+    Embedding config for tensorflow hub models. This config can be used with
+    MLTransform to embed text data. Models are loaded using the RunInference
+    PTransform with the help of a ModelHandler.
+    Args:
+      columns: The columns containing the text to be embedded.
+      hub_url: The url of the tensorflow hub model.
+      preprocessing_url: The url of the preprocessing model. This is optional.
+        If provided, the preprocessing model will be used to preprocess the
+        text before feeding it to the main model.
+      min_batch_size: The minimum batch size to be used for inference.
+      max_batch_size: The maximum batch size to be used for inference.
+      large_model: Whether to share the model across processes.
+    """
+    super().__init__(columns=columns, **kwargs)
+    self.model_uri = hub_url
+    self.preprocessing_url = preprocessing_url
+
+  def get_model_handler(self) -> ModelHandler:
+    # override the default inference function
+    return _TensorflowHubModelHandler(
+        model_uri=self.model_uri,
+        preprocessing_url=self.preprocessing_url,
+        min_batch_size=self.min_batch_size,
+        max_batch_size=self.max_batch_size,
+        large_model=self.large_model,
+    )
+
+  def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:

Review Comment:
   Nit: consider adding a docstring to `get_ptransform_for_processing`.



##########
sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py:
##########
@@ -0,0 +1,129 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+import apache_beam as beam
+import tensorflow as tf
+import tensorflow_hub as hub
+import tensorflow_text as text  # required to register TF ops. # pylint: 
disable=unused-import
+from apache_beam.ml.inference import utils
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
+from apache_beam.ml.inference.tensorflow_inference import 
default_tensor_inference_fn
+from apache_beam.ml.transforms.base import EmbeddingsManager
+from apache_beam.ml.transforms.base import _TextEmbeddingHandler
+
+__all__ = ['TensorflowHubTextEmbeddings']
+
+
+# TODO: https://github.com/apache/beam/issues/30288
+# Replace with TFModelHandlerTensor when load_model() supports TFHUB models.
+class _TensorflowHubModelHandler(TFModelHandlerTensor):
+  """
+  Note: Intended for internal use only. No backwards compatibility guarantees.
+  """
+  def __init__(self, preprocessing_url: Optional[str], *args, **kwargs):
+    self.preprocessing_url = preprocessing_url
+    super().__init__(*args, **kwargs)
+
+  def load_model(self):
+    # unable to load the models with tf.keras.models.load_model so
+    # using hub.KerasLayer instead
+    model = hub.KerasLayer(self._model_uri, **self._load_model_args)
+    return model
+
+  def _convert_prediction_result_to_list(
+      self, predictions: Iterable[PredictionResult]):
+    result = []
+    for prediction in predictions:
+      inference = prediction.inference.numpy().tolist()
+      result.append(inference)
+    return result
+
+  def run_inference(self, batch, model, inference_args, model_id=None):
+    if not inference_args:
+      inference_args = {}
+    if not self.preprocessing_url:
+      predictions = default_tensor_inference_fn(
+          model=model,
+          batch=batch,
+          inference_args=inference_args,
+          model_id=model_id)
+      return self._convert_prediction_result_to_list(predictions)
+
+    vectorized_batch = tf.stack(batch, axis=0)
+    preprocessor_fn = hub.KerasLayer(self.preprocessing_url)
+    vectorized_batch = preprocessor_fn(vectorized_batch)
+    predictions = model(vectorized_batch)
+    # 
https://www.tensorflow.org/text/tutorials/classify_text_with_bert#using_the_bert_model
 # pylint: disable=line-too-long
+    # pooled_output -> represents the text as a whole. This is an embeddings
+    # of the whole text. The shape is [batch_size, embedding_dimension]
+    # sequence_output -> represents the text as a sequence of tokens. This is
+    # an embeddings of each token in the text. The shape is
+    # [batch_size, max_sequence_length, embedding_dimension]
+    # pooled output is the embeedings as per the documentation. so let's use
+    # that.
+    embeddings = predictions['pooled_output']
+    predictions = utils._convert_to_result(batch, embeddings, model_id)
+    return self._convert_prediction_result_to_list(predictions)
+
+
+class TensorflowHubTextEmbeddings(EmbeddingsManager):
+  def __init__(
+      self,
+      columns: List[str],
+      hub_url: str,
+      preprocessing_url: Optional[str] = None,
+      **kwargs):
+    """
+    Embedding config for tensorflow hub models. This config can be used with
+    MLTransform to embed text data. Models are loaded using the RunInference
+    PTransform with the help of a ModelHandler.
+    Args:
+      columns: The columns containing the text to be embedded.
+      hub_url: The url of the tensorflow hub model.
+      preprocessing_url: The url of the preprocessing model. This is optional.
+        If provided, the preprocessing model will be used to preprocess the
+        text before feeding it to the main model.
+      min_batch_size: The minimum batch size to be used for inference.
+      max_batch_size: The maximum batch size to be used for inference.
+      large_model: Whether to share the model across processes.
+    """

Review Comment:
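   The docstring lists `min_batch_size`, `max_batch_size`, and `large_model`, but they are not in the explicit signature. Should we state that these args must be passed as keyword arguments via `**kwargs`?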



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to