This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new 4ba34bc feat(llm): add auth for fastapi and gradio (#70)
4ba34bc is described below
commit 4ba34bc48e1f834286237163d38a81c43d164454
Author: chenzihong <[email protected]>
AuthorDate: Tue Aug 20 13:02:03 2024 +0800
feat(llm): add auth for fastapi and gradio (#70)
1. add auth to gradio (using the auth provided by gradio)
auth – If provided, username and password (or list of username-password
tuples) required to access the gradio app. Can also provide function that takes
username and password and returns True if valid login.
refer: https://www.gradio.app/main/docs/gradio/blocks
2. add auth to fastapi (using APIRouter)
refer: https://fastapi.tiangolo.com/reference/apirouter/#fastapi.APIRouter
---------
Co-authored-by: imbajin <[email protected]>
---
hugegraph-llm/src/hugegraph_llm/api/rag_api.py | 12 +++++------
.../src/hugegraph_llm/demo/rag_web_demo.py | 25 +++++++++++++++++++---
2 files changed, 28 insertions(+), 9 deletions(-)
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index a9c834c..b6d8068 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from fastapi import FastAPI, status
+from fastapi import status, APIRouter
from hugegraph_llm.api.models.rag_response import RAGResponse
from hugegraph_llm.config import settings
@@ -23,8 +23,8 @@ from hugegraph_llm.api.models.rag_requests import RAGRequest,
GraphConfigRequest
from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
-def rag_http_api(app: FastAPI, rag_answer_func, apply_graph_conf,
apply_llm_conf, apply_embedding_conf):
- @app.post("/rag", status_code=status.HTTP_200_OK)
+def rag_http_api(router: APIRouter, rag_answer_func, apply_graph_conf,
apply_llm_conf, apply_embedding_conf):
+ @router.post("/rag", status_code=status.HTTP_200_OK)
def rag_answer_api(req: RAGRequest):
result = rag_answer_func(req.query, req.raw_llm, req.vector_only,
req.graph_only, req.graph_vector)
return {
@@ -33,13 +33,13 @@ def rag_http_api(app: FastAPI, rag_answer_func,
apply_graph_conf, apply_llm_conf
if getattr(req, key)
}
- @app.post("/config/graph", status_code=status.HTTP_201_CREATED)
+ @router.post("/config/graph", status_code=status.HTTP_201_CREATED)
def graph_config_api(req: GraphConfigRequest):
# Accept status code
res = apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd,
req.gs, origin_call="http")
return generate_response(RAGResponse(status_code=res, message="Missing
Value"))
- @app.post("/config/llm", status_code=status.HTTP_201_CREATED)
+ @router.post("/config/llm", status_code=status.HTTP_201_CREATED)
def llm_config_api(req: LLMConfigRequest):
settings.llm_type = req.llm_type
@@ -53,7 +53,7 @@ def rag_http_api(app: FastAPI, rag_answer_func,
apply_graph_conf, apply_llm_conf
res = apply_llm_conf(req.host, req.port, req.language_model, None,
origin_call="http")
return generate_response(RAGResponse(status_code=res, message="Missing
Value"))
- @app.post("/config/embedding", status_code=status.HTTP_201_CREATED)
+ @router.post("/config/embedding", status_code=status.HTTP_201_CREATED)
def embedding_config_api(req: LLMConfigRequest):
settings.embedding_type = req.llm_type
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index 756cb1c..c924186 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -24,7 +24,8 @@ import docx
import gradio as gr
import requests
import uvicorn
-from fastapi import FastAPI
+from fastapi import FastAPI, Depends, APIRouter
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from requests.auth import HTTPBasicAuth
from hugegraph_llm.api.rag_api import rag_http_api
@@ -40,6 +41,19 @@ from hugegraph_llm.utils.hugegraph_utils import
init_hg_test_data, run_gremlin_q
from hugegraph_llm.utils.log import log
from hugegraph_llm.utils.vector_index_utils import clean_vector_index
+sec = HTTPBearer()
+
+
+def authenticate(credentials: HTTPAuthorizationCredentials = Depends(sec)):
+ correct_token = os.getenv("TOKEN")
+ if credentials.credentials != correct_token:
+ from fastapi import HTTPException
+ raise HTTPException(
+ status_code=401,
+ detail="Invalid token, please contact the admin",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+
def rag_answer(
text: str, raw_answer: bool, vector_only_answer: bool,
graph_only_answer: bool, graph_vector_answer: bool
@@ -460,12 +474,17 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=8001, help="port")
args = parser.parse_args()
app = FastAPI()
+ app_auth = APIRouter(dependencies=[Depends(authenticate)])
hugegraph_llm = init_rag_ui()
+ rag_http_api(app_auth, rag_answer, apply_graph_config, apply_llm_config,
apply_embedding_config)
- rag_http_api(app, rag_answer, apply_graph_config, apply_llm_config,
apply_embedding_config)
+ app.include_router(app_auth)
+ auth_enabled = os.getenv("ENABLE_LOGIN", "False").lower() == "true"
+ log.info("Authentication is %s.", "enabled" if auth_enabled else
"disabled")
+ # TODO: support multi-user login when need
+ app = gr.mount_gradio_app(app, hugegraph_llm, path="/", auth=("rag",
os.getenv("TOKEN")) if auth_enabled else None)
- app = gr.mount_gradio_app(app, hugegraph_llm, path="/")
# Note: set reload to False in production environment
uvicorn.run(app, host=args.host, port=args.port)
# TODO: we can't use reload now due to the config 'app' of uvicorn.run