josix commented on code in PR #41554:
URL: https://github.com/apache/airflow/pull/41554#discussion_r1723260101
##########
airflow/providers/openai/operators/openai.py:
##########
@@ -74,3 +78,92 @@ def execute(self, context: Context) -> list[float]:
embeddings = self.hook.create_embeddings(self.input_text,
model=self.model, **self.embedding_kwargs)
self.log.info("Generated embeddings for %d items", len(embeddings))
return embeddings
+
+
+class OpenAITriggerBatchOperator(BaseOperator):
+ """
+ Operator that triggers an OpenAI Batch API endpoint and waits for the
batch to complete.
+
+ :param file_id: Required. The ID of the batch file to trigger.
+ :param endpoint: Required. The OpenAI Batch API endpoint to trigger.
+ :param conn_id: Optional. The OpenAI connection ID to use. Defaults to
'openai_default'.
+ :param deferrable: Optional. Run operator in the deferrable mode.
+ :param wait_seconds: Optional. Number of seconds between checks. Only used
when ``deferrable`` is False.
+ Defaults to 3 seconds.
+ :param timeout: Optional. The amount of time, in seconds, to wait for the
request to complete.
+ Only used when ``deferrable`` is False. Defaults to 24 hour, which is
the SLA for OpenAI Batch API.
+ :param wait_for_completion: Optional. Whether to wait for the batch to
complete. If set to False, the operator
+ will return immediately after triggering the batch. Defaults to True.
+
+ .. seealso::
+ For more information on how to use this operator, please take a look
at the guide:
+ :ref:`howto/operator:OpenAITriggerBatchOperator`
+ """
+
+ template_fields: Sequence[str] = ("file_id",)
+
+ def __init__(
+ self,
+ file_id: str,
+ endpoint: Literal["/v1/chat/completions", "/v1/embeddings",
"/v1/completions"],
+ conn_id: str = OpenAIHook.default_conn_name,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ wait_seconds: float = 3,
+ timeout: float = 24 * 60 * 60,
+ wait_for_completion: bool = True,
+ **kwargs: Any,
+ ):
+ super().__init__(**kwargs)
+ self.conn_id = conn_id
+ self.file_id = file_id
+ self.endpoint = endpoint
+ self.deferrable = deferrable
+ self.wait_seconds = wait_seconds
+ self.timeout = timeout
+ self.wait_for_completion = wait_for_completion
+ self.batch_id: str | None = None
+
+ @cached_property
+ def hook(self) -> OpenAIHook:
+ """Return an instance of the OpenAIHook."""
+ return OpenAIHook(conn_id=self.conn_id)
+
+ def execute(self, context: Context) -> str:
+ batch = self.hook.create_batch(file_id=self.file_id,
endpoint=self.endpoint)
+ self.batch_id = batch.id
+ if not self.wait_for_completion:
+ return self.batch_id
Review Comment:
I moved the check for whether to wait for completion into the outer block,
and the deferrable check into the inner one. Hopefully this makes the flow
easier for others to follow. PTAL, thanks!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]