This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git


The following commit(s) were added to refs/heads/main by this push:
     new 4f6414d  feat(llm): support basic rag api(v1) (#64)
4f6414d is described below

commit 4f6414d420b94fd539957a02751f43463d248097
Author: chenzihong <[email protected]>
AuthorDate: Mon Aug 19 02:24:53 2024 +0800

    feat(llm): support basic rag api(v1) (#64)
    
    * refact: support test wenxin/ollama conn
    
    ---------
    
    Co-authored-by: Hongjun Li <[email protected]>
    Co-authored-by: imbajin <[email protected]>
---
 .gitignore                                         |   1 +
 .../api/{rag_api.py => exceptions/__init__.py}     |   0
 .../{rag_api.py => exceptions/rag_exceptions.py}   |  21 ++
 .../api/{rag_api.py => models/__init__.py}         |   0
 .../src/hugegraph_llm/api/models/rag_requests.py   |  52 +++
 .../api/{rag_api.py => models/rag_response.py}     |   7 +
 hugegraph-llm/src/hugegraph_llm/api/rag_api.py     |  50 +++
 hugegraph-llm/src/hugegraph_llm/config/config.py   |   2 +-
 .../src/hugegraph_llm/demo/rag_web_demo.py         | 383 +++++++++++----------
 .../{api/rag_api.py => enums/build_mode.py}        |  11 +
 10 files changed, 350 insertions(+), 177 deletions(-)

diff --git a/.gitignore b/.gitignore
index de24191..786b5e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -164,3 +164,4 @@ cython_debug/
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
 .idea/
+*.DS_Store
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/exceptions/__init__.py
similarity index 100%
copy from hugegraph-llm/src/hugegraph_llm/api/rag_api.py
copy to hugegraph-llm/src/hugegraph_llm/api/exceptions/__init__.py
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py
similarity index 50%
copy from hugegraph-llm/src/hugegraph_llm/api/rag_api.py
copy to hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py
index 13a8339..24ef7c1 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py
@@ -14,3 +14,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from fastapi import HTTPException
+from hugegraph_llm.api.models.rag_response import RAGResponse
+
+
+class ExternalException(HTTPException):
+    def __init__(self):
+        super().__init__(status_code=400, detail="Connect failed with error code -1, please check the input.")
+
+
+class ConnectionFailedException(HTTPException):
+    def __init__(self, status_code: int, message: str):
+        super().__init__(status_code=status_code, detail=message)
+
+
+def generate_response(response: RAGResponse) -> dict:
+    if response.status_code == -1:
+        raise ExternalException()
+    elif not (200 <= response.status_code < 300):
+        raise ConnectionFailedException(response.status_code, response.message)
+    return {"message": "Connection successful. Configured finished."}
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/models/__init__.py
similarity index 100%
copy from hugegraph-llm/src/hugegraph_llm/api/rag_api.py
copy to hugegraph-llm/src/hugegraph_llm/api/models/__init__.py
diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
new file mode 100644
index 0000000..d12a1b8
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pydantic import BaseModel
+from typing import Optional
+
+
+class RAGRequest(BaseModel):
+    query: str
+    raw_llm: Optional[bool] = True
+    vector_only: Optional[bool] = False
+    graph_only: Optional[bool] = False
+    graph_vector: Optional[bool] = False
+
+
+class GraphConfigRequest(BaseModel):
+    ip: str = "127.0.0.1"
+    port: str = "8080"
+    name: str = "hugegraph"
+    user: str = "xxx"
+    pwd: str = "xxx"
+    gs: str = None
+
+
+class LLMConfigRequest(BaseModel):
+    llm_type: str
+    # The common parameters shared by OpenAI, Qianfan Wenxin,
+    # and OLLAMA platforms.
+    api_key: str
+    api_base: str
+    language_model: str
+    # Openai-only properties
+    max_tokens: str = None
+    # qianfan-wenxin-only properties
+    secret_key: str = None
+    # ollama-only properties
+    host: str = None
+    port: str = None
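
A minimal usage sketch for these request models (illustrative values; only `query` is
required on RAGRequest, the answer-mode flags keep the defaults shown above):

    from hugegraph_llm.api.models.rag_requests import RAGRequest

    req = RAGRequest(query="Tell me about Al Pacino.", graph_vector=True)
    print(req.dict())  # use .model_dump() instead on pydantic v2
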
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_response.py
similarity index 87%
copy from hugegraph-llm/src/hugegraph_llm/api/rag_api.py
copy to hugegraph-llm/src/hugegraph_llm/api/models/rag_response.py
index 13a8339..fe139ee 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_response.py
@@ -14,3 +14,10 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from pydantic import BaseModel
+
+
+class RAGResponse(BaseModel):
+    status_code: int = -1
+    message: str = ""
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index 13a8339..a9c834c 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -14,3 +14,53 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from fastapi import FastAPI, status
+
+from hugegraph_llm.api.models.rag_response import RAGResponse
+from hugegraph_llm.config import settings
+from hugegraph_llm.api.models.rag_requests import RAGRequest, GraphConfigRequest, LLMConfigRequest
+from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
+
+
+def rag_http_api(app: FastAPI, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf):
+    @app.post("/rag", status_code=status.HTTP_200_OK)
+    def rag_answer_api(req: RAGRequest):
+        result = rag_answer_func(req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector)
+        return {
+            key: value
+            for key, value in zip(["raw_llm", "vector_only", "graph_only", "graph_vector"], result)
+            if getattr(req, key)
+        }
+
+    @app.post("/config/graph", status_code=status.HTTP_201_CREATED)
+    def graph_config_api(req: GraphConfigRequest):
+        # Accept status code
+        res = apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http")
+        return generate_response(RAGResponse(status_code=res, message="Missing Value"))
+
+    @app.post("/config/llm", status_code=status.HTTP_201_CREATED)
+    def llm_config_api(req: LLMConfigRequest):
+        settings.llm_type = req.llm_type
+
+        if req.llm_type == "openai":
+            res = apply_llm_conf(
+                req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http"
+            )
+        elif req.llm_type == "qianfan_wenxin":
+            res = apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http")
+        else:
+            res = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http")
+        return generate_response(RAGResponse(status_code=res, message="Missing Value"))
+
+    @app.post("/config/embedding", status_code=status.HTTP_201_CREATED)
+    def embedding_config_api(req: LLMConfigRequest):
+        settings.embedding_type = req.llm_type
+
+        if req.llm_type == "openai":
+            res = apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http")
+        elif req.llm_type == "qianfan_wenxin":
+            res = apply_embedding_conf(req.api_key, req.api_base, None, origin_call="http")
+        else:
+            res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http")
+        return generate_response(RAGResponse(status_code=res, message="Missing Value"))
diff --git a/hugegraph-llm/src/hugegraph_llm/config/config.py b/hugegraph-llm/src/hugegraph_llm/config/config.py
index 62d41d4..3659cc1 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/config.py
@@ -67,7 +67,7 @@ class Config:
     """HugeGraph settings"""
     graph_ip: Optional[str] = "127.0.0.1"
     graph_port: Optional[int] = 8080
-    # graph_space: Optional[str] = "DEFAULT"
+    graph_space: Optional[str] = None
     graph_name: Optional[str] = "hugegraph"
     graph_user: Optional[str] = "admin"
     graph_pwd: Optional[str] = "xxx"
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index 01bc85c..756cb1c 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -16,47 +16,38 @@
 # under the License.
 
 
-import json
 import argparse
+import json
 import os
 
-import requests
-import uvicorn
 import docx
 import gradio as gr
+import requests
+import uvicorn
 from fastapi import FastAPI
+from requests.auth import HTTPBasicAuth
 
-from hugegraph_llm.models.llms.init_llm import LLMs
+from hugegraph_llm.api.rag_api import rag_http_api
+from hugegraph_llm.config import settings, resource_path
+from hugegraph_llm.enums.build_mode import BuildMode
 from hugegraph_llm.models.embeddings.init_embedding import Embeddings
+from hugegraph_llm.models.llms.init_llm import LLMs
 from hugegraph_llm.operators.graph_rag_task import GraphRAG
 from hugegraph_llm.operators.kg_construction_task import KgBuilder
-from hugegraph_llm.config import settings, resource_path
 from hugegraph_llm.operators.llm_op.property_graph_extract import SCHEMA_EXAMPLE_PROMPT
-from hugegraph_llm.utils.hugegraph_utils import (
-    init_hg_test_data,
-    run_gremlin_query,
-    clean_hg_data
-)
-from hugegraph_llm.utils.log import log
 from hugegraph_llm.utils.hugegraph_utils import get_hg_client
+from hugegraph_llm.utils.hugegraph_utils import init_hg_test_data, run_gremlin_query, clean_hg_data
+from hugegraph_llm.utils.log import log
 from hugegraph_llm.utils.vector_index_utils import clean_vector_index
 
 
-def convert_bool_str(string):
-    if string == "true":
-        return True
-    if string == "false":
-        return False
-    raise gr.Error(f"Invalid boolean string: {string}")
-
+def rag_answer(
+        text: str, raw_answer: bool, vector_only_answer: bool, graph_only_answer: bool, graph_vector_answer: bool
+) -> tuple:
+    vector_search = vector_only_answer or graph_vector_answer
+    graph_search = graph_only_answer or graph_vector_answer
 
-# TODO: enhance/distinguish the "graph_rag" name to avoid confusion
-def graph_rag(text: str, raw_answer: str, vector_only_answer: str,
-              graph_only_answer: str, graph_vector_answer):
-    vector_search = convert_bool_str(vector_only_answer) or convert_bool_str(graph_vector_answer)
-    graph_search = convert_bool_str(graph_only_answer) or convert_bool_str(graph_vector_answer)
-
-    if raw_answer == "false" and not vector_search and not graph_search:
+    if raw_answer is False and not vector_search and not graph_search:
         gr.Warning("Please select at least one generate mode.")
         return "", "", "", ""
     searcher = GraphRAG()
@@ -65,10 +56,10 @@ def graph_rag(text: str, raw_answer: str, vector_only_answer: str,
     if graph_search:
         searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag()
     searcher.merge_dedup_rerank().synthesize_answer(
-        raw_answer=convert_bool_str(raw_answer),
-        vector_only_answer=convert_bool_str(vector_only_answer),
-        graph_only_answer=convert_bool_str(graph_only_answer),
-        graph_vector_answer=convert_bool_str(graph_vector_answer)
+        raw_answer=raw_answer,
+        vector_only_answer=vector_only_answer,
+        graph_only_answer=graph_only_answer,
+        graph_vector_answer=graph_vector_answer,
     ).run(verbose=True, query=text)
 
     try:
@@ -77,7 +68,7 @@ def graph_rag(text: str, raw_answer: str, vector_only_answer: str,
             context.get("raw_answer", ""),
             context.get("vector_only_answer", ""),
             context.get("graph_only_answer", ""),
-            context.get("graph_vector_answer", "")
+            context.get("graph_vector_answer", ""),
         )
     except ValueError as e:
         log.error(e)
@@ -87,7 +78,7 @@ def graph_rag(text: str, raw_answer: str, vector_only_answer: str,
         raise gr.Error(f"An unexpected error occurred: {str(e)}")
 
 
-def build_kg(file, schema, example_prompt, build_mode):  # pylint: disable=too-many-branches
+def build_kg(file, schema, example_prompt, build_mode) -> str:  # pylint: disable=too-many-branches
     full_path = file.name
     if full_path.endswith(".txt"):
         with open(full_path, "r", encoding="utf-8") as f:
@@ -99,12 +90,13 @@ def build_kg(file, schema, example_prompt, build_mode):  # pylint: disable=too-m
             text += para.text
             text += "\n"
     elif full_path.endswith(".pdf"):
+        # TODO: support PDF file
         raise gr.Error("PDF will be supported later! Try to upload text/docx 
now")
     else:
         raise gr.Error("Please input txt or docx file.")
     builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
 
-    if build_mode != "Rebuild vertex index":
+    if build_mode != BuildMode.REBUILD_VERTEX_INDEX.value:
         if schema:
             try:
                 schema = json.loads(schema.strip())
@@ -116,39 +108,146 @@ def build_kg(file, schema, example_prompt, build_mode):  # pylint: disable=too-m
             return "ERROR: please input schema."
     builder.chunk_split(text, "paragraph", "zh")
 
-    # TODO: avoid hardcoding the "build_mode" strings (use var/constant instead)
-    if build_mode == "Rebuild Vector":
+    if build_mode == BuildMode.REBUILD_VECTOR.value:
         builder.fetch_graph_data()
     else:
         builder.extract_info(example_prompt, "property_graph")
     # "Test Mode", "Import Mode", "Clear and Import", "Rebuild Vector"
-    if build_mode != "Test Mode":
-        if build_mode in ("Clear and Import", "Rebuild Vector"):
+    if build_mode != BuildMode.TEST_MODE.value:
+        if build_mode in (BuildMode.CLEAR_AND_IMPORT.value, BuildMode.REBUILD_VECTOR.value):
             clean_vector_index()
         builder.build_vector_index()
-    if build_mode == "Clear and Import":
+    if build_mode == BuildMode.CLEAR_AND_IMPORT.value:
         clean_hg_data()
-    if build_mode in ("Clear and Import", "Import Mode"):
+    if build_mode in (BuildMode.CLEAR_AND_IMPORT.value, BuildMode.IMPORT_MODE.value):
         builder.commit_to_hugegraph()
-    if build_mode != "Test Mode":
+    if build_mode != BuildMode.TEST_MODE.value:
         builder.build_vertex_id_semantic_index()
     log.debug(builder.operators)
     try:
         context = builder.run()
-        return context
+        return str(context)
     except Exception as e:  # pylint: disable=broad-exception-caught
         log.error(e)
         raise gr.Error(str(e))
 
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="0.0.0.0", help="host")
-    parser.add_argument("--port", type=int, default=8001, help="port")
-    args = parser.parse_args()
-    app = FastAPI()
-
-    with gr.Blocks() as hugegraph_llm:
+def test_api_connection(url, method="GET",
+                        headers=None, params=None, body=None, auth=None, origin_call=None) -> int:
+    # TODO: use fastapi.request / starlette instead?
+    log.debug("Request URL: %s", url)
+    try:
+        if method.upper() == "GET":
+            resp = requests.get(url, headers=headers, params=params, timeout=5, auth=auth)
+        elif method.upper() == "POST":
+            resp = requests.post(url, headers=headers, params=params, json=body, timeout=5, auth=auth)
+        else:
+            raise ValueError("Unsupported HTTP method, please use GET/POST 
instead")
+    except requests.exceptions.RequestException as e:
+        msg = f"Connection failed: {e}"
+        log.error(msg)
+        if origin_call is None:
+            raise gr.Error(msg)
+        return -1  # Error code
+
+    if 200 <= resp.status_code < 300:
+        msg = "Test connection successful~"
+        log.info(msg)
+        gr.Info(msg)
+    else:
+        msg = f"Connection failed with status code: {resp.status_code}, error: 
{resp.text}"
+        log.error(msg)
+        # TODO: Only the message returned by rag can be processed, and the other return values can't be processed
+        if origin_call is None:
+            raise gr.Error(json.loads(resp.text).get("message", msg))
+    return resp.status_code
+
+
+def config_qianfan_model(arg1, arg2, arg3=None, origin_call=None) -> int:
+    settings.qianfan_api_key = arg1
+    settings.qianfan_secret_key = arg2
+    settings.qianfan_language_model = arg3
+    params = {
+        "grant_type": "client_credentials",
+        "client_id": arg1,
+        "client_secret": arg2
+    }
+    status_code = test_api_connection("https://aip.baidubce.com/oauth/2.0/token", "POST", params=params,
+                                      origin_call=origin_call)
+    return status_code
+
+
+def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int:
+    status_code = -1
+    embedding_option = settings.embedding_type
+    if embedding_option == "openai":
+        settings.openai_api_key = arg1
+        settings.openai_api_base = arg2
+        settings.openai_embedding_model = arg3
+        test_url = settings.openai_api_base + "/models"
+        headers = {"Authorization": f"Bearer {arg1}"}
+        status_code = test_api_connection(test_url, headers=headers, origin_call=origin_call)
+    elif embedding_option == "qianfan_wenxin":
+        status_code = config_qianfan_model(arg1, arg2, origin_call=origin_call)
+        settings.qianfan_embedding_model = arg3
+    elif embedding_option == "ollama":
+        settings.ollama_host = arg1
+        settings.ollama_port = int(arg2)
+        settings.ollama_embedding_model = arg3
+        # TODO: right way to test ollama conn?
+        status_code = test_api_connection(f"http://{arg1}:{arg2}/status";, 
origin_call=origin_call)
+    settings.update_env()
+    gr.Info("Configured!")
+    return status_code
+
+
+def apply_graph_config(ip, port, name, user, pwd, gs, origin_call=None) -> int:
+    settings.graph_ip = ip
+    settings.graph_port = int(port)
+    settings.graph_name = name
+    settings.graph_user = user
+    settings.graph_pwd = pwd
+    settings.graph_space = gs
+    # Test graph connection (Auth)
+    if gs and gs.strip():
+        test_url = f"http://{ip}:{port}/graphspaces/{gs}/graphs/{name}/schema";
+    else:
+        test_url = f"http://{ip}:{port}/graphs/{name}/schema";
+    auth = HTTPBasicAuth(user, pwd)
+    # for http api return status
+    response = test_api_connection(test_url, auth=auth, origin_call=origin_call)
+    settings.update_env()
+    return response
+
+
+# Different llm models have different parameters,
+# so no meaningful argument names are given here
+def apply_llm_config(arg1, arg2, arg3, arg4, origin_call=None) -> int:
+    llm_option = settings.llm_type
+    status_code = -1
+    if llm_option == "openai":
+        settings.openai_api_key = arg1
+        settings.openai_api_base = arg2
+        settings.openai_language_model = arg3
+        settings.openai_max_tokens = int(arg4)
+        test_url = settings.openai_api_base + "/models"
+        headers = {"Authorization": f"Bearer {arg1}"}
+        status_code = test_api_connection(test_url, headers=headers, origin_call=origin_call)
+    elif llm_option == "qianfan_wenxin":
+        status_code = config_qianfan_model(arg1, arg2, arg3, origin_call)
+    elif llm_option == "ollama":
+        settings.ollama_host = arg1
+        settings.ollama_port = int(arg2)
+        settings.ollama_language_model = arg3
+        # TODO: right way to test ollama conn?
+        status_code = test_api_connection(f"http://{arg1}:{arg2}/status";, 
origin_call=origin_call)
+    gr.Info("Configured!")
+    settings.update_env()
+    return status_code
+
+
+def init_rag_ui() -> gr.Interface:
+    with gr.Blocks() as hugegraph_llm_ui:
         gr.Markdown(
             """# HugeGraph LLM RAG Demo
         1. Set up the HugeGraph server."""
@@ -159,51 +258,17 @@ if __name__ == "__main__":
                 gr.Textbox(value=str(settings.graph_port), label="port"),
                 gr.Textbox(value=settings.graph_name, label="graph"),
                 gr.Textbox(value=settings.graph_user, label="user"),
-                gr.Textbox(value=settings.graph_pwd, label="pwd")
+                gr.Textbox(value=settings.graph_pwd, label="pwd", type="password"),
+                # gr.Textbox(value=settings.graph_space, label="graphspace (None)"),
+                # wip: graph_space issue pending
+                gr.Textbox(value="", label="graphspace (None)"),
             ]
         graph_config_button = gr.Button("apply configuration")
 
-
-        def test_api_connection(url, method="GET", ak=None, sk=None, 
headers=None, body=None):
-            # TODO: use fastapi.request / starlette instead? (Also add a 
try-catch here)
-            log.debug("Request URL: %s", url)
-            if method.upper() == "GET":
-                response = requests.get(url, headers=headers, timeout=5)
-            elif method.upper() == "POST":
-                response = requests.post(url, headers=headers, json=body, timeout=5)
-            else:
-                log.error("Unsupported method: %s", method)
-                return
-
-            if 200 <= response.status_code < 300:
-                log.info("Connection successful. Configured finished.")
-                gr.Info("Connection successful. Configured finished.")
-            else:
-                log.error("Connection failed with status code: %s", 
response.status_code)
-                # pylint: disable=pointless-exception-statement
-                gr.Error(f"Connection failed with status code: 
{response.status_code}")
-
-
-        def apply_graph_configuration(ip, port, name, user, pwd):
-            settings.graph_ip = ip
-            settings.graph_port = int(port)
-            settings.graph_name = name
-            settings.graph_user = user
-            settings.graph_pwd = pwd
-            test_url = f"http://{ip}:{port}/graphs/{name}/schema";
-            test_api_connection(test_url)
-            settings.update_env()
-
-
-        graph_config_button.click(apply_graph_configuration, inputs=graph_config_input)  # pylint: disable=no-member
+        graph_config_button.click(apply_graph_config, inputs=graph_config_input)  # pylint: disable=no-member
 
         gr.Markdown("2. Set up the LLM.")
-        llm_dropdown = gr.Dropdown(
-            choices=["openai", "qianfan_wenxin", "ollama"],
-            value=settings.llm_type,
-            label="LLM"
-        )
-
+        llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama"], value=settings.llm_type, label="LLM")
 
         @gr.render(inputs=[llm_dropdown])
         def llm_settings(llm_type):
@@ -222,58 +287,28 @@ if __name__ == "__main__":
                         gr.Textbox(value=settings.ollama_host, label="host"),
                         gr.Textbox(value=str(settings.ollama_port), label="port"),
                         gr.Textbox(value=settings.ollama_language_model, label="model_name"),
-                        gr.Textbox(value="", visible=False)
+                        gr.Textbox(value="", visible=False),
                     ]
             elif llm_type == "qianfan_wenxin":
                 with gr.Row():
                     llm_config_input = [
-                        gr.Textbox(value=settings.qianfan_api_key, label="api_key",
-                                   type="password"),
-                        gr.Textbox(value=settings.qianfan_secret_key, label="secret_key",
-                                   type="password"),
+                        gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"),
+                        gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"),
                         gr.Textbox(value=settings.qianfan_language_model, label="model_name"),
-                        gr.Textbox(value="", visible=False)
+                        gr.Textbox(value="", visible=False),
                     ]
                 log.debug(llm_config_input)
             else:
                 llm_config_input = []
             llm_config_button = gr.Button("apply configuration")
 
-            def apply_llm_configuration(arg1, arg2, arg3, arg4):
-                llm_option = settings.llm_type
-
-                if llm_option == "openai":
-                    settings.openai_api_key = arg1
-                    settings.openai_api_base = arg2
-                    settings.openai_language_model = arg3
-                    settings.openai_max_tokens = int(arg4)
-                    test_url = settings.openai_api_base + "/models"
-                    headers = {"Authorization": f"Bearer {arg1}"}
-                    test_api_connection(test_url, headers=headers, ak=arg1)
-                elif llm_option == "qianfan_wenxin":
-                    settings.qianfan_api_key = arg1
-                    settings.qianfan_secret_key = arg2
-                    settings.qianfan_language_model = arg3
-                    # TODO: test the connection
-                    # test_url = "https://aip.baidubce.com/oauth/2.0/token";  # 
POST
-                elif llm_option == "ollama":
-                    settings.ollama_host = arg1
-                    settings.ollama_port = int(arg2)
-                    settings.ollama_language_model = arg3
-                gr.Info("configured!")
-                settings.update_env()
-
-            llm_config_button.click(apply_llm_configuration, inputs=llm_config_input)  # pylint: disable=no-member
-
+            llm_config_button.click(apply_llm_config, inputs=llm_config_input)  # pylint: disable=no-member
 
         gr.Markdown("3. Set up the Embedding.")
         embedding_dropdown = gr.Dropdown(
-            choices=["openai", "ollama", "qianfan_wenxin"],
-            value=settings.embedding_type,
-            label="Embedding"
+            choices=["openai", "qianfan_wenxin", "ollama"], 
value=settings.embedding_type, label="Embedding"
         )
 
-
         @gr.render(inputs=[embedding_dropdown])
         def embedding_settings(embedding_type):
             settings.embedding_type = embedding_type
@@ -282,15 +317,13 @@ if __name__ == "__main__":
                     embedding_config_input = [
                         gr.Textbox(value=settings.openai_api_key, label="api_key", type="password"),
                         gr.Textbox(value=settings.openai_api_base, label="api_base"),
-                        gr.Textbox(value=settings.openai_embedding_model, label="model_name")
+                        gr.Textbox(value=settings.openai_embedding_model, label="model_name"),
                     ]
             elif embedding_type == "qianfan_wenxin":
                 with gr.Row():
                     embedding_config_input = [
-                        gr.Textbox(value=settings.qianfan_api_key, label="api_key",
-                                   type="password"),
-                        gr.Textbox(value=settings.qianfan_secret_key, label="secret_key",
-                                   type="password"),
+                        gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"),
+                        gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"),
                         gr.Textbox(value=settings.qianfan_embedding_model, label="model_name"),
                     ]
             elif embedding_type == "ollama":
@@ -302,31 +335,13 @@ if __name__ == "__main__":
                     ]
             else:
                 embedding_config_input = []
-            embedding_config_button = gr.Button("apply configuration")
 
-            def apply_embedding_configuration(arg1, arg2, arg3):
-                embedding_option = settings.embedding_type
-                if embedding_option == "openai":
-                    settings.openai_api_key = arg1
-                    settings.openai_api_base = arg2
-                    settings.openai_embedding_model = arg3
-                    test_url = settings.openai_api_base + "/models"
-                    headers = {"Authorization": f"Bearer {arg1}"}
-                    test_api_connection(test_url, headers=headers, ak=arg1)
-                elif embedding_option == "ollama":
-                    settings.ollama_host = arg1
-                    settings.ollama_port = int(arg2)
-                    settings.ollama_embedding_model = arg3
-                elif embedding_option == "qianfan_wenxin":
-                    settings.qianfan_access_token = arg1
-                    settings.qianfan_embed_url = arg2
-                settings.update_env()
-
-                gr.Info("configured!")
-
-            embedding_config_button.click(apply_embedding_configuration,  # pylint: disable=no-member
-                                          inputs=embedding_config_input)
+            embedding_config_button = gr.Button("apply configuration")
 
+            # Call the separate apply_embedding_configuration function here
+            embedding_config_button.click(
+                apply_embedding_config, inputs=embedding_config_input  # pylint: disable=no-member
+            )
 
         gr.Markdown(
             """## 1. Build vector/graph RAG (💡)
@@ -344,7 +359,7 @@ if __name__ == "__main__":
 """
         )
 
-        SCHEMA = """{
+        schema = """{
   "vertexlabels": [
     {
       "id":1,
@@ -380,21 +395,20 @@ if __name__ == "__main__":
 }"""
 
         with gr.Row():
-            input_file = gr.File(value=os.path.join(resource_path, "demo", "test.txt"),
-                                 label="Document")
-            input_schema = gr.Textbox(value=SCHEMA, label="Schema")
-            info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT,
-                                               label="Info extract head")
+            input_file = gr.File(value=os.path.join(resource_path, "demo", "test.txt"), label="Document")
+            input_schema = gr.Textbox(value=schema, label="Schema")
+            info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head")
             with gr.Column():
-                mode = gr.Radio(choices=["Test Mode", "Import Mode", "Clear 
and Import", "Rebuild Vector"],
-                                value="Test Mode", label="Build mode")
+                mode = gr.Radio(
+                    choices=["Test Mode", "Import Mode", "Clear and Import", 
"Rebuild Vector"],
+                    value="Test Mode",
+                    label="Build mode",
+                )
                 btn = gr.Button("Build Vector/Graph RAG")
         with gr.Row():
             out = gr.Textbox(label="Output", show_copy_button=True)
         btn.click(  # pylint: disable=no-member
-            fn=build_kg,
-            inputs=[input_file, input_schema, info_extract_template, mode],
-            outputs=out
+            fn=build_kg, inputs=[input_file, input_schema, info_extract_template, mode], outputs=out
         )
 
         gr.Markdown("""## 2. RAG with HugeGraph 📖""")
@@ -406,33 +420,50 @@ if __name__ == "__main__":
                 graph_only_out = gr.Textbox(label="Graph-only Answer", show_copy_button=True)
                 graph_vector_out = gr.Textbox(label="Graph-Vector Answer", show_copy_button=True)
             with gr.Column(scale=1):
-                raw_radio = gr.Radio(choices=["true", "false"], value="false",
-                                     label="Basic LLM Answer")
-                vector_only_radio = gr.Radio(choices=["true", "false"], 
value="true",
-                                             label="Vector-only Answer")
-                graph_only_radio = gr.Radio(choices=["true", "false"], 
value="false",
-                                            label="Graph-only Answer")
-                graph_vector_radio = gr.Radio(choices=["true", "false"], 
value="false",
-                                              label="Graph-Vector Answer")
+                raw_radio = gr.Radio(choices=[True, False], value=True, label="Basic LLM Answer")
+                vector_only_radio = gr.Radio(choices=[True, False], value=False, label="Vector-only Answer")
+                graph_only_radio = gr.Radio(choices=[True, False], value=False, label="Graph-only Answer")
+                graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer")
                 btn = gr.Button("Answer Question")
-        btn.click(fn=graph_rag, inputs=[inp, raw_radio, vector_only_radio, graph_only_radio, # pylint: disable=no-member
-                                        graph_vector_radio],
-                  outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out])
+        btn.click(
+            fn=rag_answer,
+            inputs=[
+                inp,
+                raw_radio,
+                vector_only_radio,
+                graph_only_radio,  # pylint: disable=no-member
+                graph_vector_radio,
+            ],
+            outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out],
+        )
 
         gr.Markdown("""## 3. Others (🚧) """)
         with gr.Row():
             with gr.Column():
                 inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin 
query", show_copy_button=True)
-                format = gr.Checkbox(label="Format JSON", value=True)
+                fmt = gr.Checkbox(label="Format JSON", value=True)
             out = gr.Textbox(label="Output", show_copy_button=True)
         btn = gr.Button("Run gremlin query on HugeGraph")
-        btn.click(fn=run_gremlin_query, inputs=[inp, format], outputs=out)  # pylint: disable=no-member
+        btn.click(fn=run_gremlin_query, inputs=[inp, fmt], outputs=out)  # pylint: disable=no-member
 
         with gr.Row():
             inp = []
             out = gr.Textbox(label="Output", show_copy_button=True)
         btn = gr.Button("(BETA) Init HugeGraph test data (🚧WIP)")
         btn.click(fn=init_hg_test_data, inputs=inp, outputs=out)  # pylint: disable=no-member
+    return hugegraph_llm_ui
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="0.0.0.0", help="host")
+    parser.add_argument("--port", type=int, default=8001, help="port")
+    args = parser.parse_args()
+    app = FastAPI()
+
+    hugegraph_llm = init_rag_ui()
+
+    rag_http_api(app, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config)
 
     app = gr.mount_gradio_app(app, hugegraph_llm, path="/")
     # Note: set reload to False in production environment
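
For reference, launching the refactored demo is unchanged; assuming the package and its
dependencies are installed, something like:

    python rag_web_demo.py --host 0.0.0.0 --port 8001
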
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/enums/build_mode.py
similarity index 76%
copy from hugegraph-llm/src/hugegraph_llm/api/rag_api.py
copy to hugegraph-llm/src/hugegraph_llm/enums/build_mode.py
index 13a8339..50db4c8 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/enums/build_mode.py
@@ -14,3 +14,14 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+
+from enum import Enum
+
+
+class BuildMode(Enum):
+    REBUILD_VECTOR = "Rebuild Vector"
+    TEST_MODE = "Test Mode"
+    IMPORT_MODE = "Import Mode"
+    CLEAR_AND_IMPORT = "Clear and Import"
+    REBUILD_VERTEX_INDEX = "Rebuild vertex index"
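
A quick sketch of the enum replacing the hardcoded mode strings (the Gradio radio widget
still passes plain strings, so comparisons go through .value, mirroring build_kg above):

    from hugegraph_llm.enums.build_mode import BuildMode

    build_mode = "Clear and Import"  # value from the "Build mode" radio
    if build_mode == BuildMode.CLEAR_AND_IMPORT.value:
        print("clean graph data before importing")
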

