This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git


The following commit(s) were added to refs/heads/main by this push:
     new aac82c0  feat(llm): record token cost & qps (#72)
aac82c0 is described below

commit aac82c049ab73d854f7a77baee8b8ece84fa6ea3
Author: chenzihong <[email protected]>
AuthorDate: Fri Aug 30 17:43:12 2024 +0800

    feat(llm): record token cost & qps (#72)
    
    Use decorators to record LLM token & time cost
    
    TODO: handle QPS in another PR
    
    ---------
    
    Co-authored-by: imbajin <[email protected]>
---
 .../src/hugegraph_llm/api/models/rag_requests.py   |  1 +
 hugegraph-llm/src/hugegraph_llm/api/rag_api.py     |  2 +-
 .../src/hugegraph_llm/demo/rag_web_demo.py         |  5 +-
 .../src/hugegraph_llm/middleware/__init__.py       | 16 +++++
 .../src/hugegraph_llm/middleware/middleware.py     | 45 +++++++++++++
 .../src/hugegraph_llm/models/llms/ollama.py        | 16 ++++-
 .../src/hugegraph_llm/models/llms/openai.py        |  7 +-
 .../src/hugegraph_llm/models/llms/qianfan.py       |  4 ++
 .../src/hugegraph_llm/operators/graph_rag_task.py  | 16 ++---
 .../operators/gremlin_generate_task.py             | 16 ++---
 .../operators/kg_construction_task.py              | 26 ++++----
 .../operators/llm_op/answer_synthesize.py          |  1 +
 .../hugegraph_llm/operators/llm_op/info_extract.py |  8 +--
 .../operators/llm_op/property_graph_extract.py     |  3 +-
 .../src/hugegraph_llm/utils/decorators.py          | 78 ++++++++++++++++++++++
 15 files changed, 202 insertions(+), 42 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index 47610f5..ce0eaa7 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -26,6 +26,7 @@ class RAGRequest(BaseModel):
     vector_only: Optional[bool] = False
     graph_only: Optional[bool] = False
     graph_vector: Optional[bool] = False
+    answer_prompt: Optional[str] = None
 
 
 class GraphConfigRequest(BaseModel):
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index e583619..923e70f 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -26,7 +26,7 @@ from hugegraph_llm.config import settings
 def rag_http_api(router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf):
     @router.post("/rag", status_code=status.HTTP_200_OK)
     def rag_answer_api(req: RAGRequest):
-        result = rag_answer_func(req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector)
+        result = rag_answer_func(req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector, req.answer_prompt)
         return {
             key: value
             for key, value in zip(["raw_llm", "vector_only", "graph_only", "graph_vector"], result)
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index f42d6e7..f065b6f 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -144,6 +144,7 @@ def build_kg(  # pylint: disable=too-many-branches
         builder.fetch_graph_data()
     else:
         builder.extract_info(example_prompt, "property_graph")
+
     # "Test Mode", "Import Mode", "Clear and Import", "Rebuild Vector"
     if build_mode != BuildMode.TEST_MODE.value:
         builder.build_vector_index()
@@ -151,7 +152,7 @@ def build_kg(  # pylint: disable=too-many-branches
         builder.commit_to_hugegraph()
     if build_mode != BuildMode.TEST_MODE.value:
         builder.build_vertex_id_semantic_index()
-    log.debug(builder.operators)
+    log.warning("Current building mode: [%s]", build_mode)
     try:
         context = builder.run()
         return str(context)
@@ -502,8 +503,10 @@ if __name__ == "__main__":
     auth_enabled = os.getenv("ENABLE_LOGIN", "False").lower() == "true"
     log.info("Authentication is %s.", "enabled" if auth_enabled else 
"disabled")
     # TODO: support multi-user login when need
+
+    app = gr.mount_gradio_app(app, hugegraph_llm, path="/", auth=("rag", os.getenv("TOKEN")) if auth_enabled else None)
 
     # TODO: we can't use reload now due to the config 'app' of uvicorn.run
     # ❎:f'{__name__}:app' / rag_web_demo:app / hugegraph_llm.demo.rag_web_demo:app
+    # TODO: merge uvicorn log to avoid duplicate log output (should be unified/fixed later)
     uvicorn.run(app, host=args.host, port=args.port, reload=False)
diff --git a/hugegraph-llm/src/hugegraph_llm/middleware/__init__.py b/hugegraph-llm/src/hugegraph_llm/middleware/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/middleware/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/hugegraph-llm/src/hugegraph_llm/middleware/middleware.py b/hugegraph-llm/src/hugegraph_llm/middleware/middleware.py
new file mode 100644
index 0000000..2ddaa13
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/middleware/middleware.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import time
+
+from fastapi import Request
+from starlette.middleware.base import BaseHTTPMiddleware
+
+from hugegraph_llm.utils.log import log
+
+
+# TODO: we could use middleware(AOP) in the future (dig out the lifecycle of gradio & fastapi)
+class UseTimeMiddleware(BaseHTTPMiddleware):
+    """Middleware to add process time to response headers"""
+    def __init__(self, app):
+        super().__init__(app)
+
+    async def dispatch(self, request: Request, call_next):
+        # TODO: handle time record for async task pool in gradio
+        start_time = time.perf_counter()
+        response = await call_next(request)
+        process_time = (time.perf_counter() - start_time) * 1000  # ms
+        unit = "ms"
+        if process_time > 1000:
+            process_time /= 1000
+            unit = "s"
+
+        response.headers["X-Process-Time"] = f"{process_time:.2f} {unit}"
+        log.info("Request process time: %.2f ms, code=%d", process_time, 
response.status_code)
+        log.info(f"{request.method} - Args: {request.query_params}, IP: 
{request.client.host}, URL: {request.url}")
+        return response
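
Note: the diff defines UseTimeMiddleware but does not show where it gets
registered. With a plain FastAPI/Starlette app it would be attached roughly
as below (a minimal sketch; the standalone app object is an assumption):

    from fastapi import FastAPI

    from hugegraph_llm.middleware.middleware import UseTimeMiddleware

    app = FastAPI()
    # BaseHTTPMiddleware subclasses are registered via add_middleware;
    # every response then carries an "X-Process-Time" header.
    app.add_middleware(UseTimeMiddleware)
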
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
index 5965599..dfd4669 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
@@ -21,7 +21,9 @@ from typing import Any, List, Optional, Callable, Dict
 import ollama
 from retry import retry
 
-from .base import BaseLLM
+import json
+from hugegraph_llm.models.llms.base import BaseLLM
+from hugegraph_llm.utils.log import log
 
 
 class OllamaClient(BaseLLM):
@@ -46,6 +48,12 @@ class OllamaClient(BaseLLM):
                 model=self.model,
                 messages=messages,
             )
+            usage = {
+                "prompt_tokens": response['prompt_eval_count'],
+                "completion_tokens": response['eval_count'],
+                "total_tokens": response['prompt_eval_count'] + 
response['eval_count'],
+            }
+            log.info("Token usage: %s", json.dumps(usage))
             return response["message"]["content"]
         except Exception as e:
             print(f"Retrying LLM call {e}")
@@ -66,6 +74,12 @@ class OllamaClient(BaseLLM):
                 model=self.model,
                 messages=messages,
             )
+            usage = {
+                "prompt_tokens": response['prompt_eval_count'],
+                "completion_tokens": response['eval_count'],
+                "total_tokens": response['prompt_eval_count'] + 
response['eval_count'],
+            }
+            log.info("Token usage: %s", json.dumps(usage))
             return response["message"]["content"]
         except Exception as e:
             print(f"Retrying LLM call {e}")
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
index 36cb11d..9d1f8ed 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
@@ -14,10 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-
+import json
 import os
+import re
 from typing import Callable, List, Optional, Dict, Any
+
 import openai
 import tiktoken
 from retry import retry
@@ -60,6 +61,7 @@ class OpenAIChat(BaseLLM):
                 max_tokens=self.max_tokens,
                 messages=messages,
             )
+            log.info("Token usage: %s", json.dumps(completions.usage))
             return completions.choices[0].message.content
         # catch context length / do not retry
         except openai.error.InvalidRequestError as e:
@@ -90,6 +92,7 @@ class OpenAIChat(BaseLLM):
                 max_tokens=self.max_tokens,
                 messages=messages,
             )
+            log.info("Token usage: %s", json.dumps(completions.usage))
             return completions.choices[0].message.content
         # catch context length / do not retry
         except openai.error.InvalidRequestError as e:
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
index 25d5e21..c03dfb2 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
@@ -22,6 +22,8 @@ import qianfan
 from retry import retry
 
 from hugegraph_llm.models.llms.base import BaseLLM
+import json
+from hugegraph_llm.utils.log import log
 
 
 class QianfanClient(BaseLLM):
@@ -47,6 +49,7 @@ class QianfanClient(BaseLLM):
             raise Exception(
                 f"Request failed with code {response.code}, message: 
{response.body['error_msg']}"
             )
+        log.info("Token usage: %s", json.dumps(response.body["usage"]))
         return response.body["result"]
 
     @retry(tries=3, delay=1)
@@ -64,6 +67,7 @@ class QianfanClient(BaseLLM):
             raise Exception(
                 f"Request failed with code {response.code}, message: 
{response.body['error_msg']}"
             )
+        log.info("Token usage: %s", json.dumps(response.body["usage"]))
         return response.body["result"]
 
     def generate_streaming(
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index 444ff45..9bef047 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -16,7 +16,6 @@
 # under the License.
 
 
-import time
 from typing import Dict, Any, Optional, List, Literal
 
 from hugegraph_llm.models.embeddings.base import BaseEmbedding
@@ -31,7 +30,7 @@ from hugegraph_llm.operators.index_op.semantic_id_query import SemanticIdQuery
 from hugegraph_llm.operators.index_op.vector_index_query import VectorIndexQuery
 from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize
 from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract
-from hugegraph_llm.utils.log import log
+from hugegraph_llm.utils.decorators import log_time, log_operator_time
 
 
 class RAGPipeline:
@@ -210,6 +209,7 @@ class RAGPipeline:
         self._operators.append(PrintResult())
         return self
 
+    @log_time("total time")
     def run(self, **kwargs) -> Dict[str, Any]:
         """
         Execute all operators in the pipeline in sequence.
@@ -222,11 +222,11 @@ class RAGPipeline:
 
         context = kwargs
         context["llm"] = self._llm
+
         for operator in self._operators:
-            log.debug("Running operator: %s", operator.__class__.__name__)
-            start = time.time()
-            context = operator.run(context)
-            log.debug("Operator %s finished in %s seconds", 
operator.__class__.__name__,
-                      time.time() - start)
-            log.debug("Context:\n%s", context)
+            context = self._run_operator(operator, context)
         return context
+
+    @log_operator_time
+    def _run_operator(self, operator, context):
+        return operator.run(context)
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py b/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
index ca55692..557f7b7 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
@@ -16,15 +16,13 @@
 # under the License.
 
 
-import time
-
 from hugegraph_llm.models.embeddings.base import BaseEmbedding
 from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.operators.common_op.print_result import PrintResult
 from hugegraph_llm.operators.index_op.build_gremlin_example_index import BuildGremlinExampleIndex
 from hugegraph_llm.operators.index_op.gremlin_example_index_query import GremlinExampleIndexQuery
 from hugegraph_llm.operators.llm_op.gremlin_generate import GremlinGenerate
-from hugegraph_llm.utils.log import log
+from hugegraph_llm.utils.decorators import log_time, log_operator_time
 
 
 class GremlinGenerator:
@@ -51,13 +49,13 @@ class GremlinGenerator:
         self.operators.append(PrintResult())
         return self
 
+    @log_time("total time")
     def run(self):
         context = {}
         for operator in self.operators:
-            log.debug("Running operator: %s", operator.__class__.__name__)
-            start = time.time()
-            context = operator.run(context)
-            log.debug("Operator %s finished in %s seconds", 
operator.__class__.__name__,
-                      time.time() - start)
-            log.debug("Context:\n%s", context)
+            context = self._run_operator(operator, context)
         return context
+
+    @log_operator_time
+    def _run_operator(self, operator, context):
+        return operator.run(context)
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py b/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py
index 876d1c8..f30a0e4 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py
@@ -16,25 +16,23 @@
 # under the License.
 
 
-import time
 from typing import Dict, Any, Optional, Literal, Union, List
 
-from pyhugegraph.client import PyHugeClient
-
-from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.models.embeddings.base import BaseEmbedding
+from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.operators.common_op.check_schema import CheckSchema
 from hugegraph_llm.operators.common_op.print_result import PrintResult
-from hugegraph_llm.operators.hugegraph_op.fetch_graph_data import FetchGraphData
-from hugegraph_llm.operators.index_op.build_semantic_index import BuildSemanticIndex
-from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndex
 from hugegraph_llm.operators.document_op.chunk_split import ChunkSplit
 from hugegraph_llm.operators.hugegraph_op.commit_to_hugegraph import CommitToKg
+from hugegraph_llm.operators.hugegraph_op.fetch_graph_data import FetchGraphData
 from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager
+from hugegraph_llm.operators.index_op.build_semantic_index import BuildSemanticIndex
+from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndex
 from hugegraph_llm.operators.llm_op.disambiguate_data import DisambiguateData
 from hugegraph_llm.operators.llm_op.info_extract import InfoExtract
 from hugegraph_llm.operators.llm_op.property_graph_extract import PropertyGraphExtract
-from hugegraph_llm.utils.log import log
+from hugegraph_llm.utils.decorators import log_time, log_operator_time
+from pyhugegraph.client import PyHugeClient
 
 
 class KgBuilder:
@@ -98,13 +96,13 @@ class KgBuilder:
         self.operators.append(PrintResult())
         return self
 
+    @log_time("total time")
     def run(self) -> Dict[str, Any]:
         context = None
         for operator in self.operators:
-            log.debug("Running operator: %s", operator.__class__.__name__)
-            start = time.time()
-            context = operator.run(context)
-            log.debug("Operator %s finished in %s seconds", 
operator.__class__.__name__,
-                      time.time() - start)
-            log.debug("Context:\n%s", context)
+            context = self._run_operator(operator, context)
         return context
+
+    @log_operator_time
+    def _run_operator(self, operator, context):
+        return operator.run(context)
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index a52fdb6..7032272 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -154,6 +154,7 @@ class AnswerSynthesize:
             task_cache["graph_vector_task"] = asyncio.create_task(
                 self._llm.agenerate(prompt=prompt)
             )
+        # TODO: use log.debug instead of print
         if task_cache.get("raw_task"):
             response = await task_cache["raw_task"]
             context["raw_answer"] = response
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py
index 1424143..cf7d250 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py
@@ -18,8 +18,8 @@
 import re
 from typing import List, Any, Dict, Optional
 
-from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.document.chunk_split import ChunkSplitter
+from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.utils.log import log
 
 SCHEMA_EXAMPLE_PROMPT = """## Main Task
@@ -152,8 +152,7 @@ class InfoExtract:
 
         for sentence in chunks:
             proceeded_chunk = self.extract_triples_by_llm(schema, sentence)
-            log.debug("[LLM] %s input: %s \n output:%s", 
self.__class__.__name__,
-                      sentence, proceeded_chunk)
+            log.debug("[Legacy] %s input: %s \n output:%s", 
self.__class__.__name__, sentence, proceeded_chunk)
             if schema:
                 extract_triples_by_regex_with_schema(schema, proceeded_chunk, context)
             else:
@@ -166,7 +165,8 @@ class InfoExtract:
             prompt = self.example_prompt + prompt
         return self.llm.generate(prompt=prompt)
 
-    def valid(self, element_id: str, max_length: int = 128):
+    # TODO: make 'max_length' be a configurable param in settings.py/settings.cfg
+    def valid(self, element_id: str, max_length: int = 256):
         if len(element_id.encode("utf-8")) >= max_length:
             log.warning("Filter out GraphElementID too long: %s", element_id)
             return False
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py
index e4d6dd2..cfc54b6 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py
@@ -102,8 +102,7 @@ class PropertyGraphExtract:
         items = []
         for chunk in chunks:
             proceeded_chunk = self.extract_property_graph_by_llm(schema, chunk)
-            log.debug("[LLM] %s input: %s \n output:%s", 
self.__class__.__name__, chunk,
-                      proceeded_chunk)
+            log.debug("[LLM] %s input: %s \n output:%s", 
self.__class__.__name__, chunk, proceeded_chunk)
             items.extend(self._extract_and_filter_label(schema, proceeded_chunk))
         items = self.filter_item(schema, items)
         for item in items:
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/decorators.py b/hugegraph-llm/src/hugegraph_llm/utils/decorators.py
new file mode 100644
index 0000000..6da2896
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/utils/decorators.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import time
+import asyncio
+from functools import wraps
+from typing import Optional, Any, Callable
+
+from hugegraph_llm.utils.log import log
+
+
+def log_elapsed_time(start_time: float, func: Callable, args: tuple, msg: Optional[str]):
+    elapse_time = time.perf_counter() - start_time
+    unit = "s"
+    if elapse_time < 1:
+        elapse_time *= 1000
+        unit = "ms"
+
+    class_name = args[0].__class__.__name__ if args else ""
+    message = f"{class_name} {msg or f'func {func.__name__}()'}"
+    log.info("%s took %.2f %s", message, elapse_time, unit)
+
+
+def log_time(msg: Optional[str] = "") -> Callable:
+    def decorator(func: Callable) -> Callable:
+        @wraps(func)
+        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
+            start_time = time.perf_counter()
+            result = await func(*args, **kwargs)
+            log_elapsed_time(start_time, func, args, msg)
+            return result
+
+        @wraps(func)
+        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
+            start_time = time.perf_counter()
+            result = func(*args, **kwargs)
+            log_elapsed_time(start_time, func, args, msg)
+            return result
+
+        if asyncio.iscoroutinefunction(func):
+            return async_wrapper
+        else:
+            return sync_wrapper
+
+    # handle "@log_time" usage -> better to use "@log_time()" instead
+    if callable(msg):
+        return decorator(msg)
+    return decorator
+
+
+def log_operator_time(func: Callable) -> Callable:
+    @wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        operator = args[1]  # the wrapped func is a bound method: (self, operator, context)
+        log.debug("Running operator: %s", operator.__class__.__name__)
+        start = time.perf_counter()
+        result = func(*args, **kwargs)
+        op_time = time.perf_counter() - start
+        # Only record time ≥ 0.01s (10ms)
+        if op_time >= 0.01:
+            log.debug("Operator %s finished in %.2f seconds", 
operator.__class__.__name__, op_time)
+            log.debug("Context:\n%s", result)
+        return result
+    return wrapper
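
Taken together, the two decorators replace the inline timing code removed
from the run() loops above. A standalone sketch of how they compose (the
class here is illustrative; the import path is the one added by this commit):

    from hugegraph_llm.utils.decorators import log_time, log_operator_time

    class DemoPipeline:
        @log_time("total time")  # logs e.g. "DemoPipeline total time took 1.23 s"
        def run(self, operators, context):
            for operator in operators:
                context = self._run_operator(operator, context)
            return context

        # log_operator_time reads args[1], so it must wrap a bound method
        # whose second positional argument is the operator being executed.
        @log_operator_time
        def _run_operator(self, operator, context):
            return operator.run(context)

Keeping _run_operator as an instance method (instead of decorating
operator.run directly) is what makes args[1] line up with the operator
inside log_operator_time.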
