josix commented on code in PR #41554:
URL: https://github.com/apache/airflow/pull/41554#discussion_r1722639512


##########
airflow/providers/openai/operators/openai.py:
##########
@@ -74,3 +78,86 @@ def execute(self, context: Context) -> list[float]:
         embeddings = self.hook.create_embeddings(self.input_text, 
model=self.model, **self.embedding_kwargs)
         self.log.info("Generated embeddings for %d items", len(embeddings))
         return embeddings
+
+
+class OpenAITriggerBatchOperator(BaseOperator):
+    """
+    Operator that triggers an OpenAI Batch API endpoint and waits for the 
batch to complete.
+
+    :param file_id: Required. The ID of the batch file to trigger.
+    :param endpoint: Required. The OpenAI Batch API endpoint to trigger.
+    :param conn_id: Optional. The OpenAI connection ID to use. Defaults to 
'openai_default'.
+    :param deferrable: Optional. Run operator in the deferrable mode.
+    :param wait_seconds: Optional. Number of seconds between checks. Only used 
when ``deferrable`` is False.
+        Defaults to 3 seconds.
+    :param timeout: Optional. The amount of time, in seconds, to wait for the 
request to complete.
+        Only used when ``deferrable`` is False. Defaults to 24 hour, which is 
the SLA for OpenAI Batch API.
+
+    .. seealso::
+        For more information on how to use this operator, please take a look 
at the guide:
+        :ref:`howto/operator:OpenAITriggerBatchOperator`
+    """
+
+    template_fields: Sequence[str] = ("file_id",)
+
+    def __init__(
+        self,
+        file_id: str,
+        endpoint: Literal["/v1/chat/completions", "/v1/embeddings", 
"/v1/completions"],
+        conn_id: str = OpenAIHook.default_conn_name,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        wait_seconds: float = 3,
+        timeout: float = 24 * 60 * 60,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.conn_id = conn_id
+        self.file_id = file_id
+        self.endpoint = endpoint
+        self.deferrable = deferrable
+        self.wait_seconds = wait_seconds
+        self.timeout = timeout
+        self.batch_id: str | None = None
+
+    @cached_property
+    def hook(self) -> OpenAIHook:
+        """Return an instance of the OpenAIHook."""
+        return OpenAIHook(conn_id=self.conn_id)
+
+    def execute(self, context: Context) -> str:
+        batch = self.hook.create_batch(file_id=self.file_id, 
endpoint=self.endpoint)
+        self.batch_id = batch.id
+        if self.deferrable:
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=OpenAIBatchTrigger(
+                    conn_id=self.conn_id,
+                    batch_id=self.batch_id,
+                    poll_interval=60,
+                    end_time=time.time() + self.timeout,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            self.log.info("Waiting for batch %s to complete", self.batch_id)
+            self.hook.wait_for_batch(self.batch_id, 
wait_seconds=self.wait_seconds, timeout=self.timeout)

Review Comment:
   Yes, exactly. Thank you for pointing that out. I've added a 
`wait_for_completion` parameter to the operator. If set to `False`, the operator 
will return the batch immediately, allowing users to proceed with downstream 
tasks without waiting.



##########
airflow/providers/openai/operators/openai.py:
##########
@@ -74,3 +78,86 @@ def execute(self, context: Context) -> list[float]:
         embeddings = self.hook.create_embeddings(self.input_text, 
model=self.model, **self.embedding_kwargs)
         self.log.info("Generated embeddings for %d items", len(embeddings))
         return embeddings
+
+
+class OpenAITriggerBatchOperator(BaseOperator):
+    """
+    Operator that triggers an OpenAI Batch API endpoint and waits for the 
batch to complete.
+
+    :param file_id: Required. The ID of the batch file to trigger.
+    :param endpoint: Required. The OpenAI Batch API endpoint to trigger.
+    :param conn_id: Optional. The OpenAI connection ID to use. Defaults to 
'openai_default'.
+    :param deferrable: Optional. Run operator in the deferrable mode.
+    :param wait_seconds: Optional. Number of seconds between checks. Only used 
when ``deferrable`` is False.
+        Defaults to 3 seconds.
+    :param timeout: Optional. The amount of time, in seconds, to wait for the 
request to complete.
+        Only used when ``deferrable`` is False. Defaults to 24 hour, which is 
the SLA for OpenAI Batch API.
+
+    .. seealso::
+        For more information on how to use this operator, please take a look 
at the guide:
+        :ref:`howto/operator:OpenAITriggerBatchOperator`
+    """
+
+    template_fields: Sequence[str] = ("file_id",)
+
+    def __init__(
+        self,
+        file_id: str,
+        endpoint: Literal["/v1/chat/completions", "/v1/embeddings", 
"/v1/completions"],
+        conn_id: str = OpenAIHook.default_conn_name,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        wait_seconds: float = 3,
+        timeout: float = 24 * 60 * 60,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.conn_id = conn_id
+        self.file_id = file_id
+        self.endpoint = endpoint
+        self.deferrable = deferrable
+        self.wait_seconds = wait_seconds
+        self.timeout = timeout
+        self.batch_id: str | None = None
+
+    @cached_property
+    def hook(self) -> OpenAIHook:
+        """Return an instance of the OpenAIHook."""
+        return OpenAIHook(conn_id=self.conn_id)
+
+    def execute(self, context: Context) -> str:
+        batch = self.hook.create_batch(file_id=self.file_id, 
endpoint=self.endpoint)
+        self.batch_id = batch.id
+        if self.deferrable:
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=OpenAIBatchTrigger(
+                    conn_id=self.conn_id,
+                    batch_id=self.batch_id,
+                    poll_interval=60,
+                    end_time=time.time() + self.timeout,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            self.log.info("Waiting for batch %s to complete", self.batch_id)
+            self.hook.wait_for_batch(self.batch_id, 
wait_seconds=self.wait_seconds, timeout=self.timeout)

Review Comment:
   Yes, exactly. Thank you for pointing that out. I've added a 
`wait_for_completion` parameter to the operator. If set to False, the operator 
will return the batch immediately, allowing users to proceed with downstream 
tasks without waiting.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to