josix commented on code in PR #41554:
URL: https://github.com/apache/airflow/pull/41554#discussion_r1723260101


##########
airflow/providers/openai/operators/openai.py:
##########
@@ -74,3 +78,92 @@ def execute(self, context: Context) -> list[float]:
         embeddings = self.hook.create_embeddings(self.input_text, 
model=self.model, **self.embedding_kwargs)
         self.log.info("Generated embeddings for %d items", len(embeddings))
         return embeddings
+
+
+class OpenAITriggerBatchOperator(BaseOperator):
+    """
+    Operator that triggers an OpenAI Batch API endpoint and waits for the 
batch to complete.
+
+    :param file_id: Required. The ID of the batch file to trigger.
+    :param endpoint: Required. The OpenAI Batch API endpoint to trigger.
+    :param conn_id: Optional. The OpenAI connection ID to use. Defaults to 
'openai_default'.
+    :param deferrable: Optional. Run operator in the deferrable mode.
+    :param wait_seconds: Optional. Number of seconds between checks. Only used 
when ``deferrable`` is False.
+        Defaults to 3 seconds.
+    :param timeout: Optional. The amount of time, in seconds, to wait for the 
request to complete.
+        Only used when ``deferrable`` is False. Defaults to 24 hour, which is 
the SLA for OpenAI Batch API.
+    :param wait_for_completion: Optional. Whether to wait for the batch to 
complete. If set to False, the operator
+        will return immediately after triggering the batch. Defaults to True.
+
+    .. seealso::
+        For more information on how to use this operator, please take a look 
at the guide:
+        :ref:`howto/operator:OpenAITriggerBatchOperator`
+    """
+
+    template_fields: Sequence[str] = ("file_id",)
+
+    def __init__(
+        self,
+        file_id: str,
+        endpoint: Literal["/v1/chat/completions", "/v1/embeddings", 
"/v1/completions"],
+        conn_id: str = OpenAIHook.default_conn_name,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        wait_seconds: float = 3,
+        timeout: float = 24 * 60 * 60,
+        wait_for_completion: bool = True,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.conn_id = conn_id
+        self.file_id = file_id
+        self.endpoint = endpoint
+        self.deferrable = deferrable
+        self.wait_seconds = wait_seconds
+        self.timeout = timeout
+        self.wait_for_completion = wait_for_completion
+        self.batch_id: str | None = None
+
+    @cached_property
+    def hook(self) -> OpenAIHook:
+        """Return an instance of the OpenAIHook."""
+        return OpenAIHook(conn_id=self.conn_id)
+
+    def execute(self, context: Context) -> str:
+        batch = self.hook.create_batch(file_id=self.file_id, 
endpoint=self.endpoint)
+        self.batch_id = batch.id
+        if not self.wait_for_completion:
+            return self.batch_id

Review Comment:
   I moved the check for whether to wait for completion into the outer block, 
and the deferrable check into the inner one. I hope this helps others 
understand it. PTAL, thanks!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to