AnandInguva commented on code in PR #30138:
URL: https://github.com/apache/beam/pull/30138#discussion_r1473248495
##########
sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py:
##########
@@ -130,6 +142,88 @@ def get_model_handler(self):
def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
# wrap the model handler in a _TextEmbeddingHandler since
# the SentenceTransformerEmbeddings works on text input data.
- return RunInference(
- model_handler=_TextEmbeddingHandler(self),
- inference_args=self.inference_args)
+ return (
+ RunInference(
+ model_handler=_TextEmbeddingHandler(self),
+ inference_args=self.inference_args,
+ ))
+
+
+class _InferenceAPIHandler(ModelHandler):
+ def __init__(self, config: 'InferenceAPIEmbeddings'):
+ super().__init__()
+ self._config = config
+
+ def load_model(self):
+ session = requests.Session()
+ session.headers.update(self._config.authorization_token)
+ return session
+
+ def run_inference(
+ self, batch, session: requests.Session, inference_args=None):
+ response = session.post(
+ self._config.api_url,
+ headers=self._config.authorization_token,
+ json={
+ "inputs": batch, "options": inference_args
+ })
+ return response.json()
+
+
+class InferenceAPIEmbeddings(EmbeddingsManager):
+ def __init__(
+ self,
+ hf_token: str,
Review Comment:
Done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]