AnandInguva commented on code in PR #30138:
URL: https://github.com/apache/beam/pull/30138#discussion_r1473230698


##########
sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py:
##########
@@ -130,6 +142,88 @@ def get_model_handler(self):
   def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
     # wrap the model handler in a _TextEmbeddingHandler since
     # the SentenceTransformerEmbeddings works on text input data.
-    return RunInference(
-        model_handler=_TextEmbeddingHandler(self),
-        inference_args=self.inference_args)
+    return (
+        RunInference(
+            model_handler=_TextEmbeddingHandler(self),
+            inference_args=self.inference_args,
+        ))
+
+
+class _InferenceAPIHandler(ModelHandler):
+  def __init__(self, config: 'InferenceAPIEmbeddings'):
+    super().__init__()
+    self._config = config
+
+  def load_model(self):
+    session = requests.Session()
+    session.headers.update(self._config.authorization_token)
+    return session
+
+  def run_inference(
+      self, batch, session: requests.Session, inference_args=None):
+    response = session.post(
+        self._config.api_url,
+        headers=self._config.authorization_token,
+        json={
+            "inputs": batch, "options": inference_args
+        })
+    return response.json()
+
+
+class InferenceAPIEmbeddings(EmbeddingsManager):
+  def __init__(
+      self,
+      hf_token: str,
+      columns: List[str],
+      model_name: Optional[str] = None, # example: "sentence-transformers/all-MiniLM-l6-v2" # pylint: disable=line-too-long
+      api_url: Optional[str] = None,
+      **kwargs,
+      ):
+    """
+    Feature extraction using HuggingFace's Inference API.
+    Intended to be used for feature-extraction. For other tasks, please
+    refer to https://huggingface.co/inference-api.
+    Args:
+      hf_token: HuggingFace token.
+      columns: List of columns to be embedded.
+      model_name: Model name used for feature extraction.
+      api_url: API url for feature extraction. If specified, model_name will be
+        ignored. If none, the default url
+        https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name} # pylint: disable=line-too-long
+        will be used.
+
+    """
+    super().__init__(columns, **kwargs)
+    self._api_url = api_url
+    self._authorization_token = {"Authorization": f"Bearer {hf_token}"}
+    self._model_name = model_name

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to