vincbeck commented on code in PR #41554:
URL: https://github.com/apache/airflow/pull/41554#discussion_r1722284792


##########
airflow/providers/openai/operators/openai.py:
##########
@@ -74,3 +78,86 @@ def execute(self, context: Context) -> list[float]:
         embeddings = self.hook.create_embeddings(self.input_text, 
model=self.model, **self.embedding_kwargs)
         self.log.info("Generated embeddings for %d items", len(embeddings))
         return embeddings
+
+
+class OpenAITriggerBatchOperator(BaseOperator):
+    """
+    Operator that triggers an OpenAI Batch API endpoint and waits for the 
batch to complete.
+
+    :param file_id: Required. The ID of the batch file to trigger.
+    :param endpoint: Required. The OpenAI Batch API endpoint to trigger.
+    :param conn_id: Optional. The OpenAI connection ID to use. Defaults to 
'openai_default'.
+    :param deferrable: Optional. Run operator in the deferrable mode.
+    :param wait_seconds: Optional. Number of seconds between checks. Only used 
when ``deferrable`` is False.
+        Defaults to 3 seconds.
+    :param timeout: Optional. The amount of time, in seconds, to wait for the 
request to complete.
+        Only used when ``deferrable`` is False. Defaults to 24 hour, which is 
the SLA for OpenAI Batch API.
+
+    .. seealso::
+        For more information on how to use this operator, please take a look 
at the guide:
+        :ref:`howto/operator:OpenAITriggerBatchOperator`
+    """
+
+    template_fields: Sequence[str] = ("file_id",)
+
+    def __init__(
+        self,
+        file_id: str,
+        endpoint: Literal["/v1/chat/completions", "/v1/embeddings", 
"/v1/completions"],
+        conn_id: str = OpenAIHook.default_conn_name,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        wait_seconds: float = 3,
+        timeout: float = 24 * 60 * 60,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.conn_id = conn_id
+        self.file_id = file_id
+        self.endpoint = endpoint
+        self.deferrable = deferrable
+        self.wait_seconds = wait_seconds
+        self.timeout = timeout
+        self.batch_id: str | None = None
+
+    @cached_property
+    def hook(self) -> OpenAIHook:
+        """Return an instance of the OpenAIHook."""
+        return OpenAIHook(conn_id=self.conn_id)
+
+    def execute(self, context: Context) -> str:
+        batch = self.hook.create_batch(file_id=self.file_id, 
endpoint=self.endpoint)
+        self.batch_id = batch.id
+        if self.deferrable:
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=OpenAIBatchTrigger(
+                    conn_id=self.conn_id,
+                    batch_id=self.batch_id,
+                    poll_interval=60,
+                    end_time=time.time() + self.timeout,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            self.log.info("Waiting for batch %s to complete", self.batch_id)
+            self.hook.wait_for_batch(self.batch_id, 
wait_seconds=self.wait_seconds, timeout=self.timeout)

Review Comment:
   I'll leave that decision to you, but in Airflow we usually provide a 
`wait_for_completion` flag that controls whether the operator waits for the 
operation to finish. I think that would be useful here: some users might want 
to trigger a batch without waiting for it to complete.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to