This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git


The following commit(s) were added to refs/heads/main by this push:
     new 4ba34bc  feat(llm): add auth for fastapi and gradio (#70)
4ba34bc is described below

commit 4ba34bc48e1f834286237163d38a81c43d164454
Author: chenzihong <[email protected]>
AuthorDate: Tue Aug 20 13:02:03 2024 +0800

    feat(llm): add auth for fastapi and gradio (#70)
    
    1. add auth to gradio (using the `auth` option provided by gradio); quoting the gradio docs:
    auth – If provided, username and password (or list of username-password 
tuples) required to access the gradio app. Can also provide function that takes 
username and password and returns True if valid login.
    refer: https://www.gradio.app/main/docs/gradio/blocks
    2. add auth to fastapi (using an APIRouter with a bearer-token dependency)
    refer: https://fastapi.tiangolo.com/reference/apirouter/#fastapi.APIRouter
    
    ---------
    
    Co-authored-by: imbajin <[email protected]>
---
 hugegraph-llm/src/hugegraph_llm/api/rag_api.py     | 12 +++++------
 .../src/hugegraph_llm/demo/rag_web_demo.py         | 25 +++++++++++++++++++---
 2 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py 
b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index a9c834c..b6d8068 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from fastapi import FastAPI, status
+from fastapi import status, APIRouter
 
 from hugegraph_llm.api.models.rag_response import RAGResponse
 from hugegraph_llm.config import settings
@@ -23,8 +23,8 @@ from hugegraph_llm.api.models.rag_requests import RAGRequest, 
GraphConfigRequest
 from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
 
 
-def rag_http_api(app: FastAPI, rag_answer_func, apply_graph_conf, 
apply_llm_conf, apply_embedding_conf):
-    @app.post("/rag", status_code=status.HTTP_200_OK)
+def rag_http_api(router: APIRouter, rag_answer_func, apply_graph_conf, 
apply_llm_conf, apply_embedding_conf):
+    @router.post("/rag", status_code=status.HTTP_200_OK)
     def rag_answer_api(req: RAGRequest):
         result = rag_answer_func(req.query, req.raw_llm, req.vector_only, 
req.graph_only, req.graph_vector)
         return {
@@ -33,13 +33,13 @@ def rag_http_api(app: FastAPI, rag_answer_func, 
apply_graph_conf, apply_llm_conf
             if getattr(req, key)
         }
 
-    @app.post("/config/graph", status_code=status.HTTP_201_CREATED)
+    @router.post("/config/graph", status_code=status.HTTP_201_CREATED)
     def graph_config_api(req: GraphConfigRequest):
         # Accept status code
         res = apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, 
req.gs, origin_call="http")
         return generate_response(RAGResponse(status_code=res, message="Missing 
Value"))
 
-    @app.post("/config/llm", status_code=status.HTTP_201_CREATED)
+    @router.post("/config/llm", status_code=status.HTTP_201_CREATED)
     def llm_config_api(req: LLMConfigRequest):
         settings.llm_type = req.llm_type
 
@@ -53,7 +53,7 @@ def rag_http_api(app: FastAPI, rag_answer_func, 
apply_graph_conf, apply_llm_conf
             res = apply_llm_conf(req.host, req.port, req.language_model, None, 
origin_call="http")
         return generate_response(RAGResponse(status_code=res, message="Missing 
Value"))
 
-    @app.post("/config/embedding", status_code=status.HTTP_201_CREATED)
+    @router.post("/config/embedding", status_code=status.HTTP_201_CREATED)
     def embedding_config_api(req: LLMConfigRequest):
         settings.embedding_type = req.llm_type
 
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py 
b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index 756cb1c..c924186 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -24,7 +24,8 @@ import docx
 import gradio as gr
 import requests
 import uvicorn
-from fastapi import FastAPI
+from fastapi import FastAPI, Depends, APIRouter
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from requests.auth import HTTPBasicAuth
 
 from hugegraph_llm.api.rag_api import rag_http_api
@@ -40,6 +41,19 @@ from hugegraph_llm.utils.hugegraph_utils import 
init_hg_test_data, run_gremlin_q
 from hugegraph_llm.utils.log import log
 from hugegraph_llm.utils.vector_index_utils import clean_vector_index
 
+sec = HTTPBearer()
+
+
+def authenticate(credentials: HTTPAuthorizationCredentials = Depends(sec)):
+    correct_token = os.getenv("TOKEN")
+    if credentials.credentials != correct_token:
+        from fastapi import HTTPException
+        raise HTTPException(
+            status_code=401,
+            detail="Invalid token, please contact the admin",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+
 
 def rag_answer(
         text: str, raw_answer: bool, vector_only_answer: bool, 
graph_only_answer: bool, graph_vector_answer: bool
@@ -460,12 +474,17 @@ if __name__ == "__main__":
     parser.add_argument("--port", type=int, default=8001, help="port")
     args = parser.parse_args()
     app = FastAPI()
+    app_auth = APIRouter(dependencies=[Depends(authenticate)])
 
     hugegraph_llm = init_rag_ui()
+    rag_http_api(app_auth, rag_answer, apply_graph_config, apply_llm_config, 
apply_embedding_config)
 
-    rag_http_api(app, rag_answer, apply_graph_config, apply_llm_config, 
apply_embedding_config)
+    app.include_router(app_auth)
+    auth_enabled = os.getenv("ENABLE_LOGIN", "False").lower() == "true"
+    log.info("Authentication is %s.", "enabled" if auth_enabled else 
"disabled")
+    # TODO: support multi-user login when need
+    app = gr.mount_gradio_app(app, hugegraph_llm, path="/", auth=("rag", 
os.getenv("TOKEN")) if auth_enabled else None)
 
-    app = gr.mount_gradio_app(app, hugegraph_llm, path="/")
     # Note: set reload to False in production environment
     uvicorn.run(app, host=args.host, port=args.port)
     # TODO: we can't use reload now due to the config 'app' of uvicorn.run

Reply via email to