Lee-W commented on code in PR #38736:
URL: https://github.com/apache/airflow/pull/38736#discussion_r1562745520


##########
airflow/providers/openai/hooks/openai.py:
##########
@@ -77,6 +89,165 @@ def get_conn(self) -> OpenAI:
             **openai_client_kwargs,
         )
 
+    def create_chat_completion(
+        self,
+        messages: list[
+            ChatCompletionSystemMessageParam
+            | ChatCompletionUserMessageParam
+            | ChatCompletionAssistantMessageParam
+            | ChatCompletionToolMessageParam
+            | ChatCompletionFunctionMessageParam
+        ],
+        model: str = "gpt-3.5-turbo",
+        **kwargs: Any,
+    ) -> list[ChatCompletionMessage]:
+        """
+        Create a model response for the given chat conversation and returns a 
list of chat completions.
+
+        :param messages: A list of messages comprising the conversation so far
+        :param model: ID of the model to use
+        """
+        response = self.conn.chat.completions.create(model=model, 
messages=messages, **kwargs)
+        return response.choices
+
+    def create_assistant(self, model: str = "gpt-3.5-turbo", **kwargs: Any) -> 
Assistant:
+        """Create an OpenAI assistant using the given model.
+
+        :param model: The OpenAI model for the assistant to use.
+        """
+        assistant = self.conn.beta.assistants.create(model=model, **kwargs)
+        return assistant
+
+    def get_assistant(self, assistant_id: str) -> Assistant:
+        """
+        Get an OpenAI assistant.
+
+        :param assistant_id: The ID of the assistant to retrieve.
+        """
+        assistant = 
self.conn.beta.assistants.retrieve(assistant_id=assistant_id)
+        return assistant
+
+    def get_assistants(self, **kwargs: Any) -> list[Assistant]:
+        """Get a list of Assistant objects."""
+        assistants = self.conn.beta.assistants.list(**kwargs)
+        return assistants.data
+
+    def get_assistant_by_name(self, assistant_name: str) -> Assistant | None:
+        """Get an OpenAI Assistant object for a given name.
+
+        :param assistant_name: The name of the assistant to retrieve
+        """
+        response = self.get_assistants()
+        for assistant in response:
+            if assistant.name == assistant_name:
+                return assistant
+        return None
+
+    def modify_assistant(self, assistant_id: str, **kwargs: Any) -> Assistant:
+        """Modify an existing Assistant object."""

Review Comment:
   The docstring is missing `:param assistant_id:`. Let's make it consistent with the other methods, which document their parameters.



##########
airflow/providers/openai/hooks/openai.py:
##########
@@ -77,6 +89,165 @@ def get_conn(self) -> OpenAI:
             **openai_client_kwargs,
         )
 
+    def create_chat_completion(
+        self,
+        messages: list[
+            ChatCompletionSystemMessageParam
+            | ChatCompletionUserMessageParam
+            | ChatCompletionAssistantMessageParam
+            | ChatCompletionToolMessageParam
+            | ChatCompletionFunctionMessageParam
+        ],
+        model: str = "gpt-3.5-turbo",
+        **kwargs: Any,
+    ) -> list[ChatCompletionMessage]:
+        """
+        Create a model response for the given chat conversation and returns a 
list of chat completions.
+
+        :param messages: A list of messages comprising the conversation so far
+        :param model: ID of the model to use
+        """
+        response = self.conn.chat.completions.create(model=model, 
messages=messages, **kwargs)
+        return response.choices
+
+    def create_assistant(self, model: str = "gpt-3.5-turbo", **kwargs: Any) -> 
Assistant:
+        """Create an OpenAI assistant using the given model.
+
+        :param model: The OpenAI model for the assistant to use.
+        """
+        assistant = self.conn.beta.assistants.create(model=model, **kwargs)
+        return assistant
+
+    def get_assistant(self, assistant_id: str) -> Assistant:
+        """
+        Get an OpenAI assistant.
+
+        :param assistant_id: The ID of the assistant to retrieve.
+        """
+        assistant = 
self.conn.beta.assistants.retrieve(assistant_id=assistant_id)
+        return assistant
+
+    def get_assistants(self, **kwargs: Any) -> list[Assistant]:
+        """Get a list of Assistant objects."""
+        assistants = self.conn.beta.assistants.list(**kwargs)
+        return assistants.data
+
+    def get_assistant_by_name(self, assistant_name: str) -> Assistant | None:
+        """Get an OpenAI Assistant object for a given name.
+
+        :param assistant_name: The name of the assistant to retrieve
+        """
+        response = self.get_assistants()
+        for assistant in response:
+            if assistant.name == assistant_name:
+                return assistant
+        return None
+
+    def modify_assistant(self, assistant_id: str, **kwargs: Any) -> Assistant:
+        """Modify an existing Assistant object."""
+        assistant = 
self.conn.beta.assistants.update(assistant_id=assistant_id, **kwargs)
+        return assistant
+
+    def delete_assistant(self, assistant_id: str) -> AssistantDeleted:
+        """Delete an OpenAI Assistant for a given ID.
+
+        :param assistant_id: The ID of the assistant to delete.
+        """
+        response = self.conn.beta.assistants.delete(assistant_id=assistant_id)
+        return response
+
+    def create_thread(self, **kwargs: Any) -> Thread:
+        """Create an OpenAI thread."""
+        thread = self.conn.beta.threads.create(**kwargs)
+        return thread
+
+    def modify_thread(self, thread_id: str, metadata: dict) -> Thread:

Review Comment:
   What is `metadata` expected to contain? Should it be typed as `dict[str, Any]`?



##########
tests/providers/openai/hooks/test_openai.py:
##########
@@ -56,6 +69,226 @@ def mock_embeddings_response():
     )
 
 
+@pytest.fixture
+def mock_completion():
+    return ChatCompletion(
+        id="chatcmpl-123",
+        object="chat.completion",
+        created=1677652288,
+        model=MODEL,
+        choices=[
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": "Hello there, how may I assist you today?",
+                },
+                "logprobs": None,
+                "finish_reason": "stop",
+            }
+        ],
+    )
+
+
+@pytest.fixture
+def mock_assistant():
+    return Assistant(
+        id=ASSISTANT_ID,
+        name=ASSISTANT_NAME,
+        object="assistant",
+        created_at=1677652288,
+        model=MODEL,
+        instructions=ASSISTANT_INSTRUCTIONS,
+        tools=[],
+        file_ids=[],
+        metadata={},
+    )
+
+
+@pytest.fixture
+def mock_assistant_list(mock_assistant):
+    return SyncCursorPage[Assistant](data=[mock_assistant])
+
+
+@pytest.fixture
+def mock_thread():
+    return Thread(id=THREAD_ID, object="thread", created_at=1698984975, 
metadata={})
+
+
+@pytest.fixture
+def mock_message():
+    return Message(
+        id=MESSAGE_ID,
+        object="thread.message",
+        created_at=1698984975,
+        thread_id=THREAD_ID,
+        status="completed",
+        role="user",
+        content=[{"type": "text", "text": {"value": "Tell me something 
interesting.", "annotations": []}}],
+        assistant_id=ASSISTANT_ID,
+        run_id=RUN_ID,
+        file_ids=[],
+        metadata={},
+    )
+
+
+@pytest.fixture
+def mock_message_list(mock_message):
+    return SyncCursorPage[Message](data=[mock_message])
+
+
+@pytest.fixture
+def mock_run():
+    return Run(
+        id=RUN_ID,
+        object="thread.run",
+        created_at=1698107661,
+        assistant_id=ASSISTANT_ID,
+        thread_id=THREAD_ID,
+        status="completed",
+        started_at=1699073476,
+        completed_at=1699073476,
+        model=MODEL,
+        instructions="You are a test assistant.",
+        tools=[],
+        file_ids=[],
+        metadata={},
+    )
+
+
+@pytest.fixture
+def mock_run_list(mock_run):
+    return SyncCursorPage[Run](data=[mock_run])
+
+
+def test_create_chat_completion(mock_openai_hook, mock_completion):
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Hello!"},
+    ]
+
+    mock_openai_hook.conn.chat.completions.create.return_value = 
mock_completion
+    completion = mock_openai_hook.create_chat_completion(model=MODEL, 
messages=messages)
+    choice = completion[0]
+    assert choice.message.content == "Hello there, how may I assist you today?"
+
+
+def test_create_assistant(mock_openai_hook, mock_assistant):
+    mock_openai_hook.conn.beta.assistants.create.return_value = mock_assistant
+    assistant = mock_openai_hook.create_assistant(
+        name=ASSISTANT_NAME, model=MODEL, instructions=ASSISTANT_INSTRUCTIONS
+    )
+    assert assistant.name == ASSISTANT_NAME
+    assert assistant.model == MODEL
+    assert assistant.instructions == ASSISTANT_INSTRUCTIONS
+
+
+def test_get_assistant(mock_openai_hook, mock_assistant):
+    mock_openai_hook.conn.beta.assistants.retrieve.return_value = 
mock_assistant
+    assistant = mock_openai_hook.get_assistant(assistant_id=ASSISTANT_ID)
+    assert assistant.name == ASSISTANT_NAME
+    assert assistant.model == MODEL
+    assert assistant.instructions == ASSISTANT_INSTRUCTIONS
+
+
+def test_get_assistants(mock_openai_hook, mock_assistant_list):
+    mock_openai_hook.conn.beta.assistants.list.return_value = 
mock_assistant_list
+    assistants = mock_openai_hook.get_assistants()
+    assert isinstance(assistants, list)
+
+
+def test_get_assistant_with_name(mock_openai_hook, mock_assistant_list):

Review Comment:
   ```suggestion
   def test_get_assistant_by_name(mock_openai_hook, mock_assistant_list):
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to