vincbeck commented on code in PR #41554:
URL: https://github.com/apache/airflow/pull/41554#discussion_r1722284792
##########
airflow/providers/openai/operators/openai.py:
##########
@@ -74,3 +78,86 @@ def execute(self, context: Context) -> list[float]:
        embeddings = self.hook.create_embeddings(self.input_text, model=self.model, **self.embedding_kwargs)
        self.log.info("Generated embeddings for %d items", len(embeddings))
        return embeddings
+
+
+class OpenAITriggerBatchOperator(BaseOperator):
+ """
+ Operator that triggers an OpenAI Batch API endpoint and waits for the
batch to complete.
+
+ :param file_id: Required. The ID of the batch file to trigger.
+ :param endpoint: Required. The OpenAI Batch API endpoint to trigger.
+ :param conn_id: Optional. The OpenAI connection ID to use. Defaults to
'openai_default'.
+ :param deferrable: Optional. Run operator in the deferrable mode.
+ :param wait_seconds: Optional. Number of seconds between checks. Only used
when ``deferrable`` is False.
+ Defaults to 3 seconds.
+ :param timeout: Optional. The amount of time, in seconds, to wait for the
request to complete.
+ Only used when ``deferrable`` is False. Defaults to 24 hour, which is
the SLA for OpenAI Batch API.
+
+ .. seealso::
+ For more information on how to use this operator, please take a look
at the guide:
+ :ref:`howto/operator:OpenAITriggerBatchOperator`
+ """
+
+    template_fields: Sequence[str] = ("file_id",)
+
+    def __init__(
+        self,
+        file_id: str,
+        endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
+        conn_id: str = OpenAIHook.default_conn_name,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        wait_seconds: float = 3,
+        timeout: float = 24 * 60 * 60,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.conn_id = conn_id
+        self.file_id = file_id
+        self.endpoint = endpoint
+        self.deferrable = deferrable
+        self.wait_seconds = wait_seconds
+        self.timeout = timeout
+        self.batch_id: str | None = None
+
+    @cached_property
+    def hook(self) -> OpenAIHook:
+        """Return an instance of the OpenAIHook."""
+        return OpenAIHook(conn_id=self.conn_id)
+
+    def execute(self, context: Context) -> str:
+        batch = self.hook.create_batch(file_id=self.file_id, endpoint=self.endpoint)
+        self.batch_id = batch.id
+        if self.deferrable:
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=OpenAIBatchTrigger(
+                    conn_id=self.conn_id,
+                    batch_id=self.batch_id,
+                    poll_interval=60,
+                    end_time=time.time() + self.timeout,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            self.log.info("Waiting for batch %s to complete", self.batch_id)
+            self.hook.wait_for_batch(self.batch_id, wait_seconds=self.wait_seconds, timeout=self.timeout)
Review Comment:
I'll leave that decision to you, but usually in Airflow we use a `wait_for_completion` flag which configures whether the operator waits. In your case I think that could be useful: some users might want to trigger a batch without waiting?
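
For illustration, a rough sketch of what I have in mind (the `wait_for_completion` flag, its `True` default, and the trailing `return self.batch_id` are my additions following the usual Airflow convention, not something already in this PR):

```python
# Sketch only: assumes __init__ also accepts and stores a hypothetical
# `wait_for_completion: bool = True` parameter, mirroring other Airflow operators.

def execute(self, context: Context) -> str:
    batch = self.hook.create_batch(file_id=self.file_id, endpoint=self.endpoint)
    self.batch_id = batch.id

    if not self.wait_for_completion:
        # Fire-and-forget: return the batch id without waiting for the batch to finish.
        return self.batch_id

    if self.deferrable:
        self.defer(
            timeout=self.execution_timeout,
            trigger=OpenAIBatchTrigger(
                conn_id=self.conn_id,
                batch_id=self.batch_id,
                poll_interval=60,
                end_time=time.time() + self.timeout,
            ),
            method_name="execute_complete",
        )
    else:
        self.log.info("Waiting for batch %s to complete", self.batch_id)
        self.hook.wait_for_batch(self.batch_id, wait_seconds=self.wait_seconds, timeout=self.timeout)
    return self.batch_id
```

With `wait_for_completion=False` the operator would just submit the batch and hand back the batch id, so a downstream task (or a sensor) could check on it later.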
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]