This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new f252903 feat(llm): update openai dependency to 1.47 (#88)
f252903 is described below
commit f252903628437958d60d660850a1790d593a4bd4
Author: Liu Jiajun <[email protected]>
AuthorDate: Sat Sep 28 22:14:55 2024 +0800
feat(llm): update openai dependency to 1.47 (#88)
---
hugegraph-llm/requirements.txt | 2 +-
.../src/hugegraph_llm/models/embeddings/openai.py | 12 +++++------
.../src/hugegraph_llm/models/llms/openai.py | 24 +++++++++++-----------
3 files changed, 18 insertions(+), 20 deletions(-)
diff --git a/hugegraph-llm/requirements.txt b/hugegraph-llm/requirements.txt
index d13e015..e10cb22 100644
--- a/hugegraph-llm/requirements.txt
+++ b/hugegraph-llm/requirements.txt
@@ -1,4 +1,4 @@
-openai~=0.28.1
+openai~=1.47.1
ollama~=0.2.1
qianfan~=0.3.18
retry~=0.9.2
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
index 2a092e7..1665817 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
@@ -18,8 +18,7 @@
from typing import Optional, List
-import os
-import openai
+from openai import OpenAI, AsyncOpenAI
class OpenAIEmbedding:
@@ -29,17 +28,16 @@ class OpenAIEmbedding:
api_key: Optional[str] = None,
api_base: Optional[str] = None
):
- openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
- openai.api_base = api_base or os.getenv("OPENAI_API_BASE")
+ self.client = OpenAI(api_key=api_key, base_url=api_base)
+ self.aclient = AsyncOpenAI(api_key=api_key, base_url=api_base)
self.embedding_model_name = model_name
- self.client = openai.Embedding()
def get_text_embedding(self, text: str) -> List[float]:
"""Comment"""
- response = self.client.create(input=text,
model=self.embedding_model_name)
+ response = self.client.embeddings.create(input=text,
model=self.embedding_model_name)
return response.data[0].embedding
async def async_get_text_embedding(self, text: str) -> List[float]:
"""Comment"""
- response = await self.client.acreate(input=text,
model=self.embedding_model_name)
+ response = await self.aclient.embeddings.create(input=text,
model=self.embedding_model_name)
return response.data[0].embedding
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
index bfdb83b..8a6ab8f 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
@@ -16,10 +16,10 @@
# under the License.
import json
-import os
from typing import Callable, List, Optional, Dict, Any
import openai
+from openai import OpenAI, AsyncOpenAI
import tiktoken
from retry import retry
@@ -38,8 +38,8 @@ class OpenAIClient(BaseLLM):
max_tokens: int = 4096,
temperature: float = 0.0,
) -> None:
- openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
- openai.api_base = api_base or os.getenv("OPENAI_API_BASE")
+ self.client = OpenAI(api_key=api_key, base_url=api_base)
+ self.aclient = AsyncOpenAI(api_key=api_key, base_url=api_base)
self.model = model_name
self.max_tokens = max_tokens
self.temperature = temperature
@@ -55,20 +55,20 @@ class OpenAIClient(BaseLLM):
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]
try:
- completions = openai.ChatCompletion.create(
+ completions = self.client.chat.completions.create(
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
messages=messages,
)
- log.info("Token usage: %s", json.dumps(completions.usage))
+ log.info("Token usage: %s", completions.usage.model_dump_json())
return completions.choices[0].message.content
# catch context length / do not retry
- except openai.error.InvalidRequestError as e:
+ except openai.BadRequestError as e:
log.critical("Fatal: %s", e)
return str(f"Error: {e}")
# catch authorization errors / do not retry
- except openai.error.AuthenticationError:
+ except openai.AuthenticationError:
log.critical("The provided OpenAI API key is invalid")
return "Error: The provided OpenAI API key is invalid"
except Exception as e:
@@ -86,20 +86,20 @@ class OpenAIClient(BaseLLM):
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]
try:
- completions = await openai.ChatCompletion.acreate(
+ completions = await self.aclient.chat.completions.create(
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
messages=messages,
)
- log.info("Token usage: %s", json.dumps(completions.usage))
+ log.info("Token usage: %s", completions.usage.model_dump_json())
return completions.choices[0].message.content
# catch context length / do not retry
- except openai.error.InvalidRequestError as e:
+ except openai.BadRequestError as e:
log.critical("Fatal: %s", e)
return str(f"Error: {e}")
# catch authorization errors / do not retry
- except openai.error.AuthenticationError:
+ except openai.AuthenticationError:
log.critical("The provided OpenAI API key is invalid")
return "Error: The provided OpenAI API key is invalid"
except Exception as e:
@@ -116,7 +116,7 @@ class OpenAIClient(BaseLLM):
if messages is None:
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]
- completions = openai.ChatCompletion.create(
+ completions = self.client.chat.completions.create(
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,