imbajin commented on code in PR #178:
URL: https://github.com/apache/incubator-hugegraph-ai/pull/178#discussion_r1970891251


##########
hugegraph-llm/src/hugegraph_llm/models/embeddings/litellm.py:
##########
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List, Optional
+import numpy as np
+
+from litellm import embedding, RateLimitError, APIError, APIConnectionError, 
aembedding
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_exponential,
+    retry_if_exception_type,
+)
+
+from hugegraph_llm.models.embeddings.base import BaseEmbedding
+from hugegraph_llm.utils.log import log
+
+
+class LiteLLMEmbedding(BaseEmbedding):
+    """Wrapper for LiteLLM Embedding that supports multiple LLM providers."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        model_name: str = "text-embedding-3-small",  # Can be any embedding 
model supported by LiteLLM
+    ) -> None:
+        self.api_key = api_key
+        self.api_base = api_base
+        self.model = model_name
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type((RateLimitError, APIConnectionError, 
APIError)),
+    )
+    def get_text_embedding(self, text: str) -> List[float]:
+        """Get embedding for a single text."""
+        try:
+            response = embedding(
+                model=self.model,
+                input=text,
+                api_key=self.api_key,
+                api_base=self.api_base,
+            )
+            log.info("Token usage: %s", response.usage)
+            return response.data[0]["embedding"]
+        except Exception as e:
+            log.error("Error in LiteLLM embedding call: %s", e)
+            # Return zero vector as fallback
+            return [0.0] * 1536  # Most common embedding dimension
+
+    def get_texts_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Get embeddings for multiple texts."""
+        try:
+            response = embedding(
+                model=self.model,
+                input=texts,
+                api_key=self.api_key,
+                api_base=self.api_base,
+            )
+            log.info("Token usage: %s", response.usage)
+            return [data["embedding"] for data in response.data]
+        except Exception as e:
+            log.error("Error in LiteLLM batch embedding call: %s", e)
+            # Return zero vectors as fallback
+            return [[0.0] * 1536 for _ in texts]  # Most common embedding 
dimension
+
+    async def async_get_text_embedding(self, text: str) -> List[float]:
+        """Get embedding for a single text asynchronously."""
+        try:
+            response = await aembedding(
+                model=self.model,
+                input=text,
+                api_key=self.api_key,
+                api_base=self.api_base,
+            )
+            log.info("Token usage: %s", response.usage)
+            return response.data[0]["embedding"]
+        except Exception as e:
+            log.error("Error in async LiteLLM embedding call: %s", e)
+            # Return zero vector as fallback
+            return [0.0] * 1536  # Most common embedding dimension
+
+    def similarity(self, embedding1: List[float], embedding2: List[float]) -> 
float:
+        """Calculate cosine similarity between two embeddings."""
+        # Convert to numpy arrays
+        emb1 = np.array(embedding1)
+        emb2 = np.array(embedding2)
+        # Handle zero vectors
+        if np.all(emb1 == 0) or np.all(emb2 == 0):
+            return 0.0
+        # Calculate cosine similarity
+        return float(np.dot(emb1, emb2) / (np.linalg.norm(emb1) * 
np.linalg.norm(emb2))) 

Review Comment:
   ```suggestion
           return float(np.dot(emb1, emb2) / (np.linalg.norm(emb1) * 
np.linalg.norm(emb2))) 
   
   ```



##########
hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py:
##########
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Callable, List, Optional, Dict, Any
+
+import tiktoken
+from litellm import completion, acompletion
+from litellm.exceptions import RateLimitError, BudgetExceededError, APIError
+
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_exponential,
+    retry_if_exception_type,
+)
+
+from hugegraph_llm.models.llms.base import BaseLLM
+from hugegraph_llm.utils.log import log
+
+
+class LiteLLMClient(BaseLLM):
+    """Wrapper for LiteLLM Client that supports multiple LLM providers."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        model_name: str = "gpt-4",  # Can be any model supported by LiteLLM
+        max_tokens: int = 4096,
+        temperature: float = 0.0,
+    ) -> None:
+        self.api_key = api_key
+        self.api_base = api_base
+        self.model = model_name
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type((RateLimitError, BudgetExceededError, 
APIError))
+    )
+    def generate(
+        self,
+        messages: Optional[List[Dict[str, Any]]] = None,
+        prompt: Optional[str] = None,
+    ) -> str:
+        """Generate a response to the query messages/prompt."""
+        if messages is None:
+            assert prompt is not None, "Messages or prompt must be provided."
+            messages = [{"role": "user", "content": prompt}]
+        try:
+            print("base_url:" + self.api_base)
+            response = completion(
+                model=self.model,
+                messages=messages,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                api_key=self.api_key,
+                base_url=self.api_base,
+            )
+            log.info("Token usage: %s", response.usage)
+            return response.choices[0].message.content
+        except Exception as e:
+            log.error("Error in LiteLLM call: %s", e)
+            return f"Error: {str(e)}"
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type((RateLimitError, BudgetExceededError, 
APIError))
+    )
+    async def agenerate(
+            self,
+            messages: Optional[List[Dict[str, Any]]] = None,
+            prompt: Optional[str] = None,
+    ) -> str:
+        """Generate a response to the query messages/prompt asynchronously."""
+        if messages is None:
+            assert prompt is not None, "Messages or prompt must be provided."
+            messages = [{"role": "user", "content": prompt}]
+        try:
+            response = await acompletion(
+                model=self.model,
+                messages=messages,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                api_key=self.api_key,
+                base_url=self.api_base,
+            )
+            log.info("Token usage: %s", response.usage)
+            return response.choices[0].message.content
+        except Exception as e:
+            log.error("Error in async LiteLLM call: %s", e)
+            return f"Error: {str(e)}"
+
+    def generate_streaming(
+        self,
+        messages: Optional[List[Dict[str, Any]]] = None,
+        prompt: Optional[str] = None,
+        on_token_callback: Callable = None,
+    ) -> str:
+        """Generate a response to the query messages/prompt in streaming 
mode."""
+        if messages is None:
+            assert prompt is not None, "Messages or prompt must be provided."
+            messages = [{"role": "user", "content": prompt}]
+        try:
+            response = completion(
+                model=self.model,
+                messages=messages,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                api_key=self.api_key,
+                base_url=self.api_base,
+                stream=True,
+            )
+            result = ""
+            for chunk in response:
+                if chunk.choices[0].delta.content:
+                    result += chunk.choices[0].delta.content
+                if on_token_callback:
+                    on_token_callback(chunk)
+            return result
+        except Exception as e:
+            log.error("Error in streaming LiteLLM call: %s", e)
+            return f"Error: {str(e)}"
+
+    def num_tokens_from_string(self, string: str) -> int:
+        """Get token count from string."""
+        try:
+            encoding = tiktoken.encoding_for_model(self.model)
+            num_tokens = len(encoding.encode(string))
+            return num_tokens
+        except Exception:
+            # Fallback for models not supported by tiktoken
+            # Rough estimate: 1 token ≈ 4 characters
+            return len(string) // 4
+
+    def max_allowed_token_length(self) -> int:
+        """Get max-allowed token length based on the model."""
+        return 4096  # Default to 4096 if model not found
+
+    def get_llm_type(self) -> str:
+        return "litellm" 

Review Comment:
   ```suggestion
           return "litellm" 
   
   ```



##########
hugegraph-llm/.gitignore:
##########
@@ -1,3 +1,5 @@
 src/hugegraph_llm/resources/demo/questions_answers.xlsx
 src/hugegraph_llm/resources/demo/questions.xlsx
 src/hugegraph_llm/resources/backup-graph-data-4020/
+
+uv.lock

Review Comment:
   ```suggestion
   uv.lock
   
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@hugegraph.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@hugegraph.apache.org
For additional commands, e-mail: issues-h...@hugegraph.apache.org

Reply via email to