damccorm commented on code in PR #30138:
URL: https://github.com/apache/beam/pull/30138#discussion_r1470022638
##########
sdks/python/apache_beam/ml/transforms/embeddings/huggingface_test.py:
##########
@@ -44,6 +45,7 @@
except ImportError:
tft = None
+_HF_TOKEN = os.environ.get('HF_INFERENCE_TOKEN')
Review Comment:
Can we add this token to our GitHub Actions YAML file so that CI verifies this test actually runs?
##########
sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py:
##########
@@ -130,6 +142,88 @@ def get_model_handler(self):
def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
# wrap the model handler in a _TextEmbeddingHandler since
# the SentenceTransformerEmbeddings works on text input data.
- return RunInference(
- model_handler=_TextEmbeddingHandler(self),
- inference_args=self.inference_args)
+ return (
+ RunInference(
+ model_handler=_TextEmbeddingHandler(self),
+ inference_args=self.inference_args,
+ ))
+
+
+class _InferenceAPIHandler(ModelHandler):
+ def __init__(self, config: 'InferenceAPIEmbeddings'):
+ super().__init__()
+ self._config = config
+
+ def load_model(self):
+ session = requests.Session()
+ session.headers.update(self._config.authorization_token)
+ return session
+
+ def run_inference(
+ self, batch, session: requests.Session, inference_args=None):
+ response = session.post(
+ self._config.api_url,
+ headers=self._config.authorization_token,
+ json={
+ "inputs": batch, "options": inference_args
+ })
+ return response.json()
+
+
+class InferenceAPIEmbeddings(EmbeddingsManager):
+ def __init__(
+ self,
+ hf_token: str,
+ columns: List[str],
+ model_name: Optional[str] = None, # example: "sentence-transformers/all-MiniLM-l6-v2" # pylint: disable=line-too-long
+ api_url: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Feature extraction using HuggingFace's Inference API.
+ Intended to be used for feature-extraction. For other tasks, please
+ refer to https://huggingface.co/inference-api.
+ Args:
+ hf_token: HuggingFace token.
+ columns: List of columns to be embedded.
+ model_name: Model name used for feature extraction.
+ api_url: API url for feature extraction. If specified, model_name will be
+ ignored. If none, the default url
+
+ https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name} # pylint: disable=line-too-long
+ will be used.
+
+ """
+ super().__init__(columns, **kwargs)
+ self._api_url = api_url
+ self._authorization_token = {"Authorization": f"Bearer {hf_token}"}
+ self._model_name = model_name
Review Comment:
We should probably validate here that exactly one of model_name and api_url
is set, rather than doing that check in api_url, so that a misconfiguration
throws at construction time instead of at runtime.
##########
sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py:
##########
@@ -130,6 +142,88 @@ def get_model_handler(self):
def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
# wrap the model handler in a _TextEmbeddingHandler since
# the SentenceTransformerEmbeddings works on text input data.
- return RunInference(
- model_handler=_TextEmbeddingHandler(self),
- inference_args=self.inference_args)
+ return (
+ RunInference(
+ model_handler=_TextEmbeddingHandler(self),
+ inference_args=self.inference_args,
+ ))
+
+
+class _InferenceAPIHandler(ModelHandler):
+ def __init__(self, config: 'InferenceAPIEmbeddings'):
+ super().__init__()
+ self._config = config
+
+ def load_model(self):
+ session = requests.Session()
+ session.headers.update(self._config.authorization_token)
+ return session
+
+ def run_inference(
+ self, batch, session: requests.Session, inference_args=None):
+ response = session.post(
+ self._config.api_url,
+ headers=self._config.authorization_token,
+ json={
+ "inputs": batch, "options": inference_args
+ })
+ return response.json()
+
+
+class InferenceAPIEmbeddings(EmbeddingsManager):
+ def __init__(
+ self,
+ hf_token: str,
Review Comment:
Should we make this optional? It looks like this kind of authorization can
also be set through the environment (see
https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hftoken),
so we could try to read the token from that environment variable if this
argument isn't set. The downside is that a missing token would then surface as
a runtime error rather than a construction-time error, but it would probably be
a more robust and secure approach for users with custom containers.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]