This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new c89eb31 feat: refactor the vector index & enhance the UI for building KG (#75)
c89eb31 is described below
commit c89eb31684e9c368bea4bbe65f0f9d16d1412845
Author: vichayturen <[email protected]>
AuthorDate: Tue Sep 17 14:38:47 2024 +0800
feat: refactor the vector index & enhance the UI for building KG (#75)
1. Added a text-box extraction method.
2. Refactored the vector index.
3. Supported differential deletion and incremental import.
4. Refactored the demo interface and added buttons.
TODOs:
- Need to move the property check to "property_graph_extract.py"
- ....
---------
Co-authored-by: imbajin <[email protected]>
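The headline change is a two-step build flow in the demo: extraction and loading are now separate button actions, with vector indexing handled independently. A minimal sketch of the flow, reusing the helper names this patch wires to the Gradio click handlers below (return types and exact signatures are assumptions):

    # Sketch only: names come from the click handlers in rag_web_demo.py below.
    from hugegraph_llm.utils.graph_index_utils import extract_graph, import_graph_data
    from hugegraph_llm.utils.vector_index_utils import build_vector_index

    files, text = [], "Meet Sarah, a 30-year-old attorney ..."
    schema = '{"vertexlabels": [...], "edgelabels": [...]}'  # or a graph name
    example_prompt = "..."                                   # the "Info extract head" prompt

    graph_json = extract_graph(files, text, schema, example_prompt)  # "Extract Graph Data (1)"
    import_graph_data(graph_json, schema)                            # "Load into GraphDB (2)"
    build_vector_index(files, text)                                  # "Import into Vector"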
---
hugegraph-llm/setup.py | 2 +-
hugegraph-llm/src/hugegraph_llm/config/config.py | 20 +-
.../src/hugegraph_llm/config/config_data.py | 16 +-
.../src/hugegraph_llm/demo/rag_web_demo.py | 460 +++++++++------------
.../src/hugegraph_llm/indices/vector_index.py | 56 ++-
.../src/hugegraph_llm/models/llms/ollama.py | 2 +-
.../src/hugegraph_llm/models/llms/openai.py | 2 +-
.../src/hugegraph_llm/models/llms/qianfan.py | 2 +-
.../src/hugegraph_llm/models/rerankers/cohere.py | 9 +-
.../models/rerankers/init_reranker.py | 5 +-
.../hugegraph_llm/models/rerankers/siliconflow.py | 7 +-
.../operators/common_op/check_schema.py | 11 +-
.../operators/common_op/merge_dedup_rerank.py | 5 +-
.../operators/hugegraph_op/commit_to_hugegraph.py | 96 +++--
.../operators/hugegraph_op/fetch_graph_data.py | 2 +
.../operators/hugegraph_op/graph_rag_query.py | 35 +-
.../operators/hugegraph_op/schema_manager.py | 12 +-
.../index_op/build_gremlin_example_index.py | 8 +-
.../operators/index_op/build_semantic_index.py | 35 +-
.../operators/index_op/build_vector_index.py | 14 +-
.../index_op/gremlin_example_index_query.py | 7 +-
.../operators/index_op/semantic_id_query.py | 53 ++-
.../operators/index_op/vector_index_query.py | 6 +-
.../operators/kg_construction_task.py | 5 +-
.../operators/llm_op/answer_synthesize.py | 19 +-
.../operators/llm_op/keyword_extract.py | 5 +-
.../operators/llm_op/property_graph_extract.py | 77 ++--
.../fetch_graph_data.py => resources/demo/css.py} | 32 +-
.../src/hugegraph_llm/utils/decorators.py | 5 +-
.../src/hugegraph_llm/utils/graph_index_utils.py | 121 ++++++
.../src/hugegraph_llm/utils/hugegraph_utils.py | 6 +-
.../src/hugegraph_llm/utils/vector_index_utils.py | 63 ++-
hugegraph-python-client/setup.py | 2 +-
.../src/pyhugegraph/utils/huge_requests.py | 5 +-
.../src/pyhugegraph/utils/util.py | 16 +-
35 files changed, 719 insertions(+), 502 deletions(-)
diff --git a/hugegraph-llm/setup.py b/hugegraph-llm/setup.py
index f33ae23..ad9877f 100644
--- a/hugegraph-llm/setup.py
+++ b/hugegraph-llm/setup.py
@@ -26,7 +26,7 @@ with open("requirements.txt", encoding="utf-8") as fp:
setuptools.setup(
name="hugegraph-llm",
- version="1.3.0",
+ version="1.5.0",
author="Apache HugeGraph Contributors",
author_email="[email protected]",
install_requires=install_requires,
diff --git a/hugegraph-llm/src/hugegraph_llm/config/config.py b/hugegraph-llm/src/hugegraph_llm/config/config.py
index e70fbde..f6fd734 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/config.py
@@ -106,28 +106,28 @@ class PromptConfig(PromptData):
def save_to_yaml(self):
- indented_schema = "\n".join([f" {line}" for line in
self.rag_schema.splitlines()])
- indented_example_prompt = "\n".join([f" {line}" for line in
self.schema_example_prompt.splitlines()])
- indented_question = "\n".join([f" {line}" for line in
self.question.splitlines()])
+ indented_schema = "\n".join([f" {line}" for line in
self.graph_schema.splitlines()])
+ indented_example_prompt = "\n".join([f" {line}" for line in
self.extract_graph_prompt.splitlines()])
+ indented_question = "\n".join([f" {line}" for line in
self.default_question.splitlines()])
indented_custom_related_information = (
- "\n".join([f" {line}" for line in
self.custom_related_information.splitlines()])
+ "\n".join([f" {line}" for line in
self.custom_rerank_info.splitlines()])
)
- indented_default_answer_template = "\n".join([f" {line}" for line
in self.default_answer_template.splitlines()])
+ indented_default_answer_template = "\n".join([f" {line}" for line
in self.answer_prompt.splitlines()])
# This can be extended to add storage fields according to the data
needs to be stored
- yaml_content = f"""rag_schema: |
+ yaml_content = f"""graph_schema: |
{indented_schema}
-schema_example_prompt: |
+extract_graph_prompt: |
{indented_example_prompt}
-question: |
+default_question: |
{indented_question}
-custom_related_information: |
+custom_rerank_info: |
{indented_custom_related_information}
-default_answer_template: |
+answer_prompt: |
{indented_default_answer_template}
"""
diff --git a/hugegraph-llm/src/hugegraph_llm/config/config_data.py b/hugegraph-llm/src/hugegraph_llm/config/config_data.py
index b2a4fd1..36d5fe8 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/config_data.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/config_data.py
@@ -74,7 +74,7 @@ class ConfigData:
class PromptData:
# Data is detached from hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
- default_answer_template = f"""You are an expert in knowledge graphs and natural language processing.
+ answer_prompt = f"""You are an expert in knowledge graphs and natural language processing.
Your task is to provide a precise and accurate answer based on the given context.
Context information is below.
@@ -88,12 +88,12 @@ Query: {{query_str}}
Answer:
"""
- custom_related_information = """"""
+ custom_rerank_info = """"""
- question = """Tell me about Sarah."""
+ default_question = """Tell me about Sarah."""
# Data is detached from hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py
- schema_example_prompt = """## Main Task
+ extract_graph_prompt = """## Main Task
Given the following graph schema and a piece of text, your task is to analyze the text and extract information that fits into the schema's structure, formatting the information into vertices and edges as specified.
## Basic Rules
### Schema Format
@@ -107,10 +107,10 @@ Please read the provided text carefully and identify any information that corres
#### Edge Format:
{"label":"edgeLabel","type":"edge","outV":"sourceVertexId","outVLabel":"sourceVertexLabel","inV":"targetVertexId","inVLabel":"targetVertexLabel","properties":{"propertyName":"propertyValue",...}}
Also follow the rules:
-1. Don't extract property fields that do not exist in the given schema
-2. Ensure the extracted property is in the same type as the schema (like 'age' should be a number)
+1. Don't extract property fields or labels that doesn't exist in the given schema
+2. Ensure the extracted property set in the same type as the given schema (like 'age' should be a number, 'select' should be a boolean)
3. If there are multiple primary keys, the strategy for generating VID is: vertexlabelID:pk1!pk2!pk3 (pk means primary key, and '!' is the separator)
-4. Output should be a list of JSON objects, each representing a vertex or an edge, extracted and formatted based on the text and schema.
+4. Output in JSON format, only include vertexes and edges & remove empty properties, extracted and formatted based on the text/rules and schema
5. Translate the schema fields into Chinese if the given text is Chinese but the schema is in English (Optional)
## Example
### Input example:
@@ -122,7 +122,7 @@ Meet Sarah, a 30-year-old attorney, and her roommate, James, whom she's shared a
[{"id":"1:Sarah","label":"person","type":"vertex","properties":{"name":"Sarah","age":30,"occupation":"attorney"}},{"id":"1:James","label":"person","type":"vertex","properties":{"name":"James","occupation":"journalist"}},{"label":"roommate","type":"edge","outV":"1:Sarah","outVLabel":"person","inV":"1:James","inVLabel":"person","properties":{"date":"2010"}}]
"""
- rag_schema = """{
+ graph_schema = """{
"vertexlabels": [
{
"id": 1,
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index aaaf0a4..6dc60b1 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -19,9 +19,8 @@
import argparse
import json
import os
-from typing import List, Union, Tuple, Literal, Optional
+from typing import Tuple, Literal, Optional
-import docx
import gradio as gr
import pandas as pd
import requests
@@ -33,17 +32,14 @@ from requests.auth import HTTPBasicAuth
from hugegraph_llm.api.rag_api import rag_http_api
from hugegraph_llm.config import settings, resource_path, prompt
-from hugegraph_llm.enums.build_mode import BuildMode
-from hugegraph_llm.models.embeddings.init_embedding import Embeddings
-from hugegraph_llm.models.llms.init_llm import LLMs
from hugegraph_llm.operators.graph_rag_task import RAGPipeline
-from hugegraph_llm.operators.kg_construction_task import KgBuilder
from hugegraph_llm.operators.llm_op.property_graph_extract import SCHEMA_EXAMPLE_PROMPT
-from hugegraph_llm.operators.llm_op.answer_synthesize import DEFAULT_ANSWER_TEMPLATE
-from hugegraph_llm.utils.hugegraph_utils import get_hg_client
-from hugegraph_llm.utils.hugegraph_utils import init_hg_test_data, run_gremlin_query, clean_hg_data
+from hugegraph_llm.resources.demo.css import CSS
+from hugegraph_llm.utils.graph_index_utils import get_graph_index_info, clean_all_graph_index, fit_vid_index, \
+ extract_graph, import_graph_data
+from hugegraph_llm.utils.hugegraph_utils import init_hg_test_data, run_gremlin_query
from hugegraph_llm.utils.log import log
-from hugegraph_llm.utils.vector_index_utils import clean_vector_index
+from hugegraph_llm.utils.vector_index_utils import clean_vector_index, build_vector_index, get_vector_index_info
sec = HTTPBearer()
@@ -73,10 +69,10 @@ def rag_answer(
answer_prompt: str,
) -> Tuple:
- if prompt.question != text or prompt.custom_related_information != custom_related_information or prompt.default_answer_template != answer_prompt:
- prompt.custom_related_information = custom_related_information
- prompt.question = text
- prompt.default_answer_template = answer_prompt
+ if prompt.default_question != text or prompt.custom_rerank_info != custom_related_information or prompt.answer_prompt != answer_prompt:
+ prompt.custom_rerank_info = custom_related_information
+ prompt.default_question = text
+ prompt.answer_prompt = answer_prompt
prompt.update_yaml_file()
vector_search = vector_only_answer or graph_vector_answer
@@ -119,78 +115,6 @@ def rag_answer(
raise gr.Error(f"An unexpected error occurred: {str(e)}")
-def build_kg( # pylint: disable=too-many-branches
- files: Union[NamedString, List[NamedString]],
- schema: str,
- example_prompt: str,
- build_mode: str,
-) -> str:
-
- # update env variables: schema and example_prompt
- if prompt.rag_schema != schema or prompt.schema_example_prompt != example_prompt:
- prompt.rag_schema = schema
- prompt.schema_example_prompt = example_prompt
- prompt.update_yaml_file()
-
- if isinstance(files, NamedString):
- files = [files]
- texts = []
- for file in files:
- full_path = file.name
- if full_path.endswith(".txt"):
- with open(full_path, "r", encoding="utf-8") as f:
- texts.append(f.read())
- elif full_path.endswith(".docx"):
- text = ""
- doc = docx.Document(full_path)
- for para in doc.paragraphs:
- text += para.text
- text += "\n"
- texts.append(text)
- elif full_path.endswith(".pdf"):
- # TODO: support PDF file
- raise gr.Error("PDF will be supported later! Try to upload
text/docx now")
- else:
- raise gr.Error("Please input txt or docx file.")
- if build_mode in (BuildMode.CLEAR_AND_IMPORT.value, BuildMode.REBUILD_VECTOR.value):
- clean_vector_index()
- if build_mode == BuildMode.CLEAR_AND_IMPORT.value:
- clean_hg_data()
- builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
-
- if build_mode != BuildMode.REBUILD_VERTEX_INDEX.value:
- if schema:
- try:
- schema = json.loads(schema.strip())
- builder.import_schema(from_user_defined=schema)
- except json.JSONDecodeError as e:
- log.error(e)
- builder.import_schema(from_hugegraph=schema)
- else:
- return "ERROR: please input schema."
- builder.chunk_split(texts, "paragraph", "zh")
-
- if build_mode == BuildMode.REBUILD_VECTOR.value:
- builder.fetch_graph_data()
- else:
- builder.extract_info(example_prompt, "property_graph")
-
- # "Test Mode", "Import Mode", "Clear and Import", "Rebuild Vector"
- if build_mode != BuildMode.TEST_MODE.value:
- builder.build_vector_index()
- if build_mode in (BuildMode.CLEAR_AND_IMPORT.value, BuildMode.IMPORT_MODE.value):
- builder.commit_to_hugegraph()
- if build_mode != BuildMode.TEST_MODE.value:
- builder.build_vertex_id_semantic_index()
- log.warning("Current building mode: [%s]", build_mode)
- try:
- context = builder.run()
- return str(context)
- except Exception as e: # pylint: disable=broad-exception-caught
- log.error(e)
- raise gr.Error(str(e))
-
-
def test_api_connection(url, method="GET", headers=None, params=None, body=None, auth=None, origin_call=None) -> int:
# TODO: use fastapi.request / starlette instead?
log.debug("Request URL: %s", url)
@@ -284,8 +208,9 @@ def apply_reranker_config(
elif reranker_option == "siliconflow":
settings.reranker_api_key = reranker_api_key
settings.reranker_model = reranker_model
+ from pyhugegraph.utils.constants import Constants
headers = {
- "accept": "application/json",
+ "accept": Constants.HEADER_CONTENT_TYPE,
"authorization": f"Bearer {reranker_api_key}",
}
status_code = test_api_connection(
@@ -346,192 +271,217 @@ def init_rag_ui() -> gr.Interface:
with gr.Blocks(
theme="default",
title="HugeGraph RAG Platform",
- css="footer {visibility: hidden}",
+ css=CSS,
) as hugegraph_llm_ui:
- gr.Markdown(
- """# HugeGraph LLM RAG Demo
- 1. Set up the HugeGraph server."""
- )
- with gr.Row():
- graph_config_input = [
- gr.Textbox(value=settings.graph_ip, label="ip"),
- gr.Textbox(value=settings.graph_port, label="port"),
- gr.Textbox(value=settings.graph_name, label="graph"),
- gr.Textbox(value=settings.graph_user, label="user"),
- gr.Textbox(value=settings.graph_pwd, label="pwd", type="password"),
- gr.Textbox(value=settings.graph_space, label="graphspace(Optional)"),
- ]
- graph_config_button = gr.Button("apply configuration")
-
+ gr.Markdown("# HugeGraph LLM RAG Demo")
+ with gr.Accordion("1. Set up the HugeGraph server.", open=False):
+ with gr.Row():
+ graph_config_input = [
+ gr.Textbox(value=settings.graph_ip, label="ip"),
+ gr.Textbox(value=settings.graph_port, label="port"),
+ gr.Textbox(value=settings.graph_name, label="graph"),
+ gr.Textbox(value=settings.graph_user, label="user"),
+ gr.Textbox(value=settings.graph_pwd, label="pwd", type="password"),
+ gr.Textbox(value=settings.graph_space, label="graphspace(Optional)"),
+ ]
+ graph_config_button = gr.Button("Apply config")
graph_config_button.click(apply_graph_config, inputs=graph_config_input) # pylint: disable=no-member
- gr.Markdown("2. Set up the LLM.")
- llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama"], value=settings.llm_type, label="LLM")
-
- @gr.render(inputs=[llm_dropdown])
- def llm_settings(llm_type):
- settings.llm_type = llm_type
- if llm_type == "openai":
- with gr.Row():
- llm_config_input = [
- gr.Textbox(value=settings.openai_api_key, label="api_key", type="password"),
- gr.Textbox(value=settings.openai_api_base, label="api_base"),
- gr.Textbox(value=settings.openai_language_model, label="model_name"),
- gr.Textbox(value=settings.openai_max_tokens, label="max_token"),
- ]
- elif llm_type == "ollama":
- with gr.Row():
- llm_config_input = [
- gr.Textbox(value=settings.ollama_host, label="host"),
- gr.Textbox(value=str(settings.ollama_port), label="port"),
- gr.Textbox(value=settings.ollama_language_model, label="model_name"),
- gr.Textbox(value="", visible=False),
- ]
- elif llm_type == "qianfan_wenxin":
- with gr.Row():
- llm_config_input = [
- gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"),
- gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"),
- gr.Textbox(value=settings.qianfan_language_model, label="model_name"),
- gr.Textbox(value="", visible=False),
- ]
- # log.debug(llm_config_input)
- else:
- llm_config_input = []
- llm_config_button = gr.Button("apply configuration")
-
- llm_config_button.click(apply_llm_config, inputs=llm_config_input) # pylint: disable=no-member
-
- gr.Markdown("3. Set up the Embedding.")
- embedding_dropdown = gr.Dropdown(
- choices=["openai", "qianfan_wenxin", "ollama"],
value=settings.embedding_type, label="Embedding"
- )
-
- @gr.render(inputs=[embedding_dropdown])
- def embedding_settings(embedding_type):
- settings.embedding_type = embedding_type
- if embedding_type == "openai":
- with gr.Row():
- embedding_config_input = [
- gr.Textbox(value=settings.openai_api_key, label="api_key", type="password"),
- gr.Textbox(value=settings.openai_api_base, label="api_base"),
- gr.Textbox(value=settings.openai_embedding_model, label="model_name"),
- ]
- elif embedding_type == "qianfan_wenxin":
- with gr.Row():
- embedding_config_input = [
- gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"),
- gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"),
- gr.Textbox(value=settings.qianfan_embedding_model, label="model_name"),
- ]
- elif embedding_type == "ollama":
- with gr.Row():
- embedding_config_input = [
- gr.Textbox(value=settings.ollama_host, label="host"),
- gr.Textbox(value=str(settings.ollama_port), label="port"),
- gr.Textbox(value=settings.ollama_embedding_model, label="model_name"),
- ]
- else:
- embedding_config_input = []
+ with gr.Accordion("2. Set up the LLM.", open=False):
+ llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama"],
+ value=settings.llm_type, label="LLM")
- embedding_config_button = gr.Button("apply configuration")
-
- # Call the separate apply_embedding_configuration function here
- embedding_config_button.click( # pylint: disable=no-member
- fn=apply_embedding_config,
- inputs=embedding_config_input, # pylint: disable=no-member
+ @gr.render(inputs=[llm_dropdown])
+ def llm_settings(llm_type):
+ settings.llm_type = llm_type
+ if llm_type == "openai":
+ with gr.Row():
+ llm_config_input = [
+ gr.Textbox(value=settings.openai_api_key, label="api_key", type="password"),
+ gr.Textbox(value=settings.openai_api_base, label="api_base"),
+ gr.Textbox(value=settings.openai_language_model, label="model_name"),
+ gr.Textbox(value=settings.openai_max_tokens, label="max_token"),
+ ]
+ elif llm_type == "ollama":
+ with gr.Row():
+ llm_config_input = [
+ gr.Textbox(value=settings.ollama_host, label="host"),
+ gr.Textbox(value=str(settings.ollama_port), label="port"),
+ gr.Textbox(value=settings.ollama_language_model, label="model_name"),
+ gr.Textbox(value="", visible=False),
+ ]
+ elif llm_type == "qianfan_wenxin":
+ with gr.Row():
+ llm_config_input = [
+ gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"),
+ gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"),
+ gr.Textbox(value=settings.qianfan_language_model, label="model_name"),
+ gr.Textbox(value="", visible=False),
+ ]
+ else:
+ llm_config_input = []
+ llm_config_button = gr.Button("apply configuration")
+ llm_config_button.click(apply_llm_config, inputs=llm_config_input) # pylint: disable=no-member
+
+ with gr.Accordion("3. Set up the Embedding.", open=False):
+ embedding_dropdown = gr.Dropdown(
+ choices=["openai", "qianfan_wenxin", "ollama"],
value=settings.embedding_type, label="Embedding"
)
- gr.Markdown("4. Set up the Reranker (Optional).")
- reranker_dropdown = gr.Dropdown(
- choices=["cohere", "siliconflow", ("default/offline", "None")],
- value=os.getenv("reranker_type") or "None",
- label="Reranker",
- )
-
- @gr.render(inputs=[reranker_dropdown])
- def reranker_settings(reranker_type):
- settings.reranker_type = reranker_type if reranker_type != "None" else None
- if reranker_type == "cohere":
- with gr.Row():
- reranker_config_input = [
- gr.Textbox(value=settings.reranker_api_key, label="api_key", type="password"),
- gr.Textbox(value=settings.reranker_model, label="model"),
- gr.Textbox(value=settings.cohere_base_url, label="base_url"),
- ]
- elif reranker_type == "siliconflow":
- with gr.Row():
- reranker_config_input = [
- gr.Textbox(value=settings.reranker_api_key, label="api_key", type="password"),
- gr.Textbox(
- value="BAAI/bge-reranker-v2-m3",
- label="model",
- info="Please refer to https://siliconflow.cn/pricing",
- ),
- ]
- else:
- reranker_config_input = []
-
- reranker_config_button = gr.Button("apply configuration")
+ @gr.render(inputs=[embedding_dropdown])
+ def embedding_settings(embedding_type):
+ settings.embedding_type = embedding_type
+ if embedding_type == "openai":
+ with gr.Row():
+ embedding_config_input = [
+ gr.Textbox(value=settings.openai_api_key, label="api_key", type="password"),
+ gr.Textbox(value=settings.openai_api_base, label="api_base"),
+ gr.Textbox(value=settings.openai_embedding_model, label="model_name"),
+ ]
+ elif embedding_type == "qianfan_wenxin":
+ with gr.Row():
+ embedding_config_input = [
+ gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"),
+ gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"),
+ gr.Textbox(value=settings.qianfan_embedding_model, label="model_name"),
+ ]
+ elif embedding_type == "ollama":
+ with gr.Row():
+ embedding_config_input = [
+ gr.Textbox(value=settings.ollama_host, label="host"),
+ gr.Textbox(value=str(settings.ollama_port), label="port"),
+ gr.Textbox(value=settings.ollama_embedding_model, label="model_name"),
+ ]
+ else:
+ embedding_config_input = []
+
+ embedding_config_button = gr.Button("apply configuration")
+
+ # Call the separate apply_embedding_configuration function here
+ embedding_config_button.click( # pylint: disable=no-member
+ fn=apply_embedding_config,
+ inputs=embedding_config_input, # pylint: disable=no-member
+ )
- # TODO: use "gr.update()" or other way to update the config in
time (refactor the click event)
- # Call the separate apply_reranker_configuration function here
- reranker_config_button.click( # pylint: disable=no-member
- fn=apply_reranker_config,
- inputs=reranker_config_input, # pylint: disable=no-member
+ with gr.Accordion("4. Set up the Reranker.", open=False):
+ reranker_dropdown = gr.Dropdown(
+ choices=["cohere", "siliconflow", ("default/offline", "None")],
+ value=os.getenv("reranker_type") or "None",
+ label="Reranker",
)
+ @gr.render(inputs=[reranker_dropdown])
+ def reranker_settings(reranker_type):
+ settings.reranker_type = reranker_type if reranker_type != "None" else None
+ if reranker_type == "cohere":
+ with gr.Row():
+ reranker_config_input = [
+ gr.Textbox(value=settings.reranker_api_key, label="api_key", type="password"),
+ gr.Textbox(value=settings.reranker_model, label="model"),
+ gr.Textbox(value=settings.cohere_base_url, label="base_url"),
+ ]
+ elif reranker_type == "siliconflow":
+ with gr.Row():
+ reranker_config_input = [
+ gr.Textbox(value=settings.reranker_api_key, label="api_key", type="password"),
+ gr.Textbox(
+ value="BAAI/bge-reranker-v2-m3",
+ label="model",
+ info="Please refer to https://siliconflow.cn/pricing",
+ ),
+ ]
+ else:
+ reranker_config_input = []
+ reranker_config_button = gr.Button("apply configuration")
+
+ # TODO: use "gr.update()" or other way to update the config in
time (refactor the click event)
+ # Call the separate apply_reranker_configuration function here
+ reranker_config_button.click( # pylint: disable=no-member
+ fn=apply_reranker_config,
+ inputs=reranker_config_input, # pylint: disable=no-member
+ )
+
gr.Markdown(
"""## 1. Build vector/graph RAG (💡)
-- Doc(s): Upload document file(s) which should be TXT or DOCX. (Multiple files can be selected together)
+- Doc(s):
+  - text: Build index from plain text.
+  - file: Upload document file(s) which should be TXT or DOCX. (Multiple files can be selected together)
- Schema: Accepts two types of text as below:
- User-defined JSON format Schema.
- Specify the name of the HugeGraph graph instance, it will automatically get the schema from it.
- Info extract head: The head of prompt of info extracting.
-- Build mode:
- - Test Mode: Only extract vertices and edges from the file into memory (without building the vector index or
- writing data into HugeGraph)
- - Import Mode: Extract the data and append it to HugeGraph & the vector index (without clearing any existing data)
- - Clear and Import: Clear all existed RAG data(vector + graph), then rebuild them from the current input
- - Rebuild Vector: Only rebuild vector index. (keep the graph data intact)
"""
)
- schema = prompt.rag_schema
+ schema = prompt.graph_schema
with gr.Row():
- input_file = gr.File(
- value=[os.path.join(resource_path, "demo", "test.txt")],
- label="Docs (multi-files can be selected together)",
- file_count="multiple",
- )
- input_schema = gr.Textbox(value=schema, label="Schema", lines=2)
- info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head", lines=2)
with gr.Column():
- mode = gr.Radio(
- choices=["Test Mode", "Import Mode", "Clear and Import",
"Rebuild Vector"],
- value="Test Mode",
- label="Build mode",
- )
- btn = gr.Button("Build Vector/Graph RAG")
+ with gr.Tab("text") as tab_upload_text:
+ input_text = gr.Textbox(value="", label="Doc(s)", lines=20, show_copy_button=True)
+ with gr.Tab("file") as tab_upload_file:
+ input_file = gr.File(
+ value=[os.path.join(resource_path, "demo", "test.txt")],
+ label="Docs (multi-files can be selected together)",
+ file_count="multiple",
+ )
+ input_schema = gr.Textbox(value=schema, label="Schema", lines=15, show_copy_button=True)
+ info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head", lines=15,
+ show_copy_button=True)
+ out = gr.Code(label="Output", language="json",
elem_classes="code-container-edit")
+
with gr.Row():
- out = gr.Textbox(label="Output", show_copy_button=True)
- btn.click( # pylint: disable=no-member
- fn=build_kg,
- inputs=[input_file, input_schema, info_extract_template, mode],
- outputs=out,
+ with gr.Accordion("Get RAG Info", open=False):
+ with gr.Column():
+ vector_index_btn0 = gr.Button("Get Vector Index Info", size="sm")
+ graph_index_btn0 = gr.Button("Get Graph Index Info", size="sm")
+ with gr.Accordion("Clear RAG Info", open=False):
+ with gr.Column():
+ vector_index_btn1 = gr.Button("Clear Vector Index", size="sm")
+ graph_index_btn1 = gr.Button("Clear Graph Data & Index", size="sm")
+
+ vector_import_bt = gr.Button("Import into Vector", variant="primary")
+ graph_index_rebuild_bt = gr.Button("Rebuild vid Index")
+ graph_extract_bt = gr.Button("Extract Graph Data (1)", variant="primary")
+ graph_loading_bt = gr.Button("Load into GraphDB (2)", interactive=True)
+
+ vector_index_btn0.click(get_vector_index_info, outputs=out) # pylint: disable=no-member
+ vector_index_btn1.click(clean_vector_index) # pylint: disable=no-member
+ vector_import_bt.click(build_vector_index, inputs=[input_file, input_text], outputs=out) # pylint: disable=no-member
+ graph_index_btn0.click(get_graph_index_info, outputs=out) # pylint: disable=no-member
+ graph_index_btn1.click(clean_all_graph_index) # pylint: disable=no-member
+ graph_index_rebuild_bt.click(fit_vid_index, outputs=out) # pylint: disable=no-member
+
+ # origin_out = gr.Textbox(visible=False)
+ graph_extract_bt.click( # pylint: disable=no-member
+ extract_graph,
+ inputs=[input_file, input_text, input_schema, info_extract_template],
+ outputs=[out]
)
+ graph_loading_bt.click(import_graph_data, inputs=[out, input_schema], outputs=[out]) # pylint: disable=no-member
+
+
+ def on_tab_select(input_f, input_t, evt: gr.SelectData):
+ print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+ if evt.value == "file":
+ return input_f, ""
+ if evt.value == "text":
+ return [], input_t
+ return [], ""
+ tab_upload_file.select(fn=on_tab_select, inputs=[input_file, input_text], outputs=[input_file, input_text]) # pylint: disable=no-member
+ tab_upload_text.select(fn=on_tab_select, inputs=[input_file, input_text], outputs=[input_file, input_text]) # pylint: disable=no-member
+
+
gr.Markdown("""## 2. RAG with HugeGraph 📖""")
with gr.Row():
with gr.Column(scale=2):
- inp = gr.Textbox(value=prompt.question, label="Question", show_copy_button=True, lines=2)
+ inp = gr.Textbox(value=prompt.default_question, label="Question", show_copy_button=True, lines=2)
raw_out = gr.Textbox(label="Basic LLM Answer", show_copy_button=True)
vector_only_out = gr.Textbox(label="Vector-only Answer", show_copy_button=True)
graph_only_out = gr.Textbox(label="Graph-only Answer", show_copy_button=True)
graph_vector_out = gr.Textbox(label="Graph-Vector Answer", show_copy_button=True)
from hugegraph_llm.operators.llm_op.answer_synthesize import DEFAULT_ANSWER_TEMPLATE
-
answer_prompt_input = gr.Textbox(
value=DEFAULT_ANSWER_TEMPLATE, label="Custom Prompt", show_copy_button=True, lines=2
)
@@ -556,14 +506,14 @@ def init_rag_ui() -> gr.Interface:
)
graph_ratio = gr.Slider(0, 1, 0.5, label="Graph Ratio", step=0.1, interactive=False)
- graph_vector_radio.change(toggle_slider, inputs=graph_vector_radio, outputs=graph_ratio)
+ graph_vector_radio.change(toggle_slider, inputs=graph_vector_radio, outputs=graph_ratio) # pylint: disable=no-member
near_neighbor_first = gr.Checkbox(
value=False,
label="Near neighbor first(Optional)",
info="One-depth neighbors > two-depth neighbors",
)
custom_related_information = gr.Text(
- prompt.custom_related_information,
+ prompt.custom_rerank_info,
label="Custom related information(Optional)",
)
btn = gr.Button("Answer Question", variant="primary")
@@ -694,18 +644,18 @@ def init_rag_ui() -> gr.Interface:
gr.Markdown("""## 4. Others (🚧) """)
with gr.Row():
- with gr.Column():
- inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin
query", show_copy_button=True)
- fmt = gr.Checkbox(label="Format JSON", value=True)
- out = gr.Textbox(label="Output", show_copy_button=True)
- btn = gr.Button("Run gremlin query on HugeGraph")
- btn.click(fn=run_gremlin_query, inputs=[inp, fmt], outputs=out) #
pylint: disable=no-member
-
- with gr.Row():
- inp = []
- out = gr.Textbox(label="Output", show_copy_button=True)
- btn = gr.Button("(BETA) Init HugeGraph test data (🚧WIP)")
- btn.click(fn=init_hg_test_data, inputs=inp, outputs=out) # pylint: disable=no-member
+ inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin query", show_copy_button=True, lines=8)
+ out = gr.Code(label="Output", language="json", elem_classes="code-container-show")
+ btn = gr.Button("Run Gremlin query")
+ btn.click(fn=run_gremlin_query, inputs=[inp], outputs=out) # pylint: disable=no-member
+
+ gr.Markdown("---")
+ with gr.Accordion("Init HugeGraph test data (🚧)", open=False):
+ with gr.Row():
+ inp = []
+ out = gr.Textbox(label="Init Graph Demo Result",
show_copy_button=True)
+ btn = gr.Button("(BETA) Init HugeGraph test data (🚧)")
+ btn.click(fn=init_hg_test_data, inputs=inp, outputs=out) #
pylint: disable=no-member
return hugegraph_llm_ui
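For reference, a minimal way to serve the rebuilt interface (the real entry point also mounts the RAG HTTP API and parses host/port via argparse; the address below is an assumption):

    if __name__ == "__main__":
        ui = init_rag_ui()
        ui.launch(server_name="0.0.0.0", server_port=8001)  # hypothetical port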
diff --git a/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py b/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py
index f073322..949d42f 100644
--- a/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py
@@ -14,24 +14,34 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-
+import logging
+import os
import pickle as pkl
from copy import deepcopy
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Set, Union
import faiss
import numpy as np
+from hugegraph_llm.utils.log import log
+
+INDEX_FILE_NAME = "index.faiss"
+PROPERTIES_FILE_NAME = "properties.pkl"
+
class VectorIndex:
"""Comment"""
- def __init__(self, embed_dim: int):
+ def __init__(self, embed_dim: int = 1024):
self.index = faiss.IndexFlatL2(embed_dim)
self.properties = []
@staticmethod
- def from_index_file(index_file: str, properties_file: str) -> "VectorIndex":
+ def from_index_file(dir_path: str) -> "VectorIndex":
+ index_file = os.path.join(dir_path, INDEX_FILE_NAME)
+ properties_file = os.path.join(dir_path, PROPERTIES_FILE_NAME)
+ if not os.path.exists(index_file) or not os.path.exists(properties_file):
+ log.warning("No index file found, create a new one.")
+ return VectorIndex()
faiss_index = faiss.read_index(index_file)
embed_dim = faiss_index.d
with open(properties_file, "rb") as f:
@@ -41,18 +51,52 @@ class VectorIndex:
vector_index.properties = properties
return vector_index
- def to_index_file(self, index_file: str, properties_file: str):
+ def to_index_file(self, dir_path: str):
+ if not os.path.exists(dir_path):
+ os.makedirs(dir_path)
+ index_file = os.path.join(dir_path, INDEX_FILE_NAME)
+ properties_file = os.path.join(dir_path, PROPERTIES_FILE_NAME)
faiss.write_index(self.index, index_file)
with open(properties_file, "wb") as f:
pkl.dump(self.properties, f)
def add(self, vectors: List[List[float]], props: List[Any]):
+ if len(vectors) == 0:
+ return
+ if self.index.ntotal == 0 and len(vectors[0]) != self.index.d:
+ self.index = faiss.IndexFlatL2(len(vectors[0]))
self.index.add(np.array(vectors))
self.properties.extend(props)
+ def remove(self, props: Union[Set[Any], List[Any]]) -> int:
+ if isinstance(props, list):
+ props = set(props)
+ indices = []
+ remove_num = 0
+ for i, p in enumerate(self.properties):
+ if p in props:
+ indices.append(i)
+ remove_num += 1
+ self.index.remove_ids(np.array(indices))
+ self.properties = [p for i, p in enumerate(self.properties) if i not in indices]
+ return remove_num
+
def search(self, query_vector: List[float], top_k: int) -> List[Dict[str, Any]]:
+ if self.index.ntotal == 0:
+ return []
+ if len(query_vector) != self.index.d:
+ raise ValueError("Query vector dimension does not match index
dimension!")
_, indices = self.index.search(np.array([query_vector]), top_k)
results = []
for i in indices[0]:
results.append(deepcopy(self.properties[i]))
return results
+
+ @staticmethod
+ def clean(dir_path: str):
+ index_file = os.path.join(dir_path, INDEX_FILE_NAME)
+ properties_file = os.path.join(dir_path, PROPERTIES_FILE_NAME)
+ if os.path.exists(index_file):
+ os.remove(index_file)
+ if os.path.exists(properties_file):
+ os.remove(properties_file)
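A minimal usage sketch of the refactored, directory-based VectorIndex API (the directory path is illustrative):

    from hugegraph_llm.indices.vector_index import VectorIndex

    d = "./resources/demo_graph/chunks"            # hypothetical index directory
    index = VectorIndex.from_index_file(d)         # falls back to an empty index
    index.add([[0.1, 0.2, 0.3]], ["chunk-1"])      # dim re-inferred on first add
    print(index.search([0.1, 0.2, 0.3], top_k=1))  # -> ['chunk-1']
    index.remove({"chunk-1"})                      # differential deletion by property
    index.to_index_file(d)                         # writes index.faiss + properties.pkl
    VectorIndex.clean(d)                           # deletes both files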
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
index dfd4669..b7b0148 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
@@ -16,6 +16,7 @@
# under the License.
+import json
from typing import Any, List, Optional, Callable, Dict
import ollama
@@ -23,7 +24,6 @@ from retry import retry
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.utils.log import log
-import json
class OllamaClient(BaseLLM):
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
index 9d1f8ed..30ac805 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
@@ -14,9 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
import json
import os
-import re
from typing import Callable, List, Optional, Dict, Any
import openai
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
index c03dfb2..eeaf938 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
@@ -16,6 +16,7 @@
# under the License.
import os
+import json
from typing import Optional, List, Dict, Any, Callable
import qianfan
@@ -23,7 +24,6 @@ from retry import retry
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.utils.log import log
-import json
class QianfanClient(BaseLLM):
diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py
index b552717..99aa60d 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py
@@ -35,14 +35,15 @@ class CohereReranker:
if not top_n:
top_n = len(documents)
assert top_n <= len(documents), "'top_n' should be less than or equal to the number of documents"
-
+
if top_n == 0:
return []
url = self.base_url
+ from pyhugegraph.utils.constants import Constants
headers = {
- "accept": "application/json",
- "content-type": "application/json",
+ "accept": Constants.HEADER_CONTENT_TYPE,
+ "content-type": Constants.HEADER_CONTENT_TYPE,
"Authorization": f"Bearer {self.api_key}",
}
payload = {
@@ -51,7 +52,7 @@ class CohereReranker:
"top_n": top_n,
"documents": documents,
}
- response = requests.post(url, headers=headers, json=payload)
+ response = requests.post(url, headers=headers, json=payload, timeout=60)
response.raise_for_status() # Raise an error for bad status codes
results = response.json()["results"]
sorted_docs = [documents[item["index"]] for item in results]
diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py
index 541f413..f80cd27 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py
@@ -29,7 +29,6 @@ class Rerankers:
return CohereReranker(
api_key=settings.reranker_api_key,
base_url=settings.cohere_base_url, model=settings.reranker_model
)
- elif self.reranker_type == "siliconflow":
+ if self.reranker_type == "siliconflow":
return SiliconReranker(api_key=settings.reranker_api_key, model=settings.reranker_model)
- else:
- raise Exception(f"reranker type is not supported !")
+ raise Exception("Reranker type is not supported!")
diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py
index a860a84..b4fa14c 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py
@@ -47,12 +47,13 @@ class SiliconReranker:
"overlap_tokens": 80,
"top_n": top_n,
}
+ from pyhugegraph.utils.constants import Constants
headers = {
- "accept": "application/json",
- "content-type": "application/json",
+ "accept": Constants.HEADER_CONTENT_TYPE,
+ "content-type": Constants.HEADER_CONTENT_TYPE,
"authorization": f"Bearer {self.api_key}",
}
- response = requests.post(url, json=payload, headers=headers)
+ response = requests.post(url, json=payload, headers=headers, timeout=60)
response.raise_for_status() # Raise an error for bad status codes
results = response.json()["results"]
sorted_docs = [documents[item["index"]] for item in results]
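Both rerankers now share the header value via pyhugegraph's constants; the assumption here is that Constants.HEADER_CONTENT_TYPE resolves to "application/json":

    from pyhugegraph.utils.constants import Constants

    headers = {
        "accept": Constants.HEADER_CONTENT_TYPE,
        "content-type": Constants.HEADER_CONTENT_TYPE,
        "authorization": "Bearer <api_key>",  # placeholder token
    }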
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py
index 7a1f64a..10ed2dc 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py
@@ -16,7 +16,7 @@
# under the License.
-from typing import Any
+from typing import Any, Optional, Dict
from hugegraph_llm.utils.log import log
@@ -25,8 +25,10 @@ class CheckSchema:
self.result = None
self.data = data
- def run(self, schema=None) -> Any: # pylint: disable=too-many-branches
- schema = self.data or schema
+ def run(self, context: Optional[Dict[str, Any]] = None) -> Any: # pylint: disable=too-many-branches
+ if context is None:
+ context = {}
+ schema = self.data or context.get("schema")
if not isinstance(schema, dict):
raise ValueError("Input data is not a dictionary.")
if "vertexlabels" not in schema or "edgelabels" not in schema:
@@ -93,4 +95,5 @@ class CheckSchema:
"'name', 'source_label', 'target_label' "
"in edge is not of correct type."
)
- return {"schema": schema}
+ context.update({"schema": schema})
+ return context
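After this change the operators converge on one convention: run() accepts an optional context dict and returns it enriched instead of replacing it. A hedged sketch (my_schema is a placeholder):

    from hugegraph_llm.operators.common_op.check_schema import CheckSchema

    my_schema = {"vertexlabels": [...], "edgelabels": [...]}
    context = CheckSchema(my_schema).run({})  # validates, then injects "schema"
    assert "schema" in context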
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
index 6e356e2..c4ff757 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
@@ -58,7 +58,7 @@ class MergeDedupRerank:
self.near_neighbor_first = near_neighbor_first
self.custom_related_information = custom_related_information
if priority:
- raise ValueError(f"Unimplemented rerank strategy: priority.")
+ raise ValueError("Unimplemented rerank strategy: priority.")
self.switch_to_bleu = False
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
@@ -106,6 +106,7 @@ class MergeDedupRerank:
if self.method == "reranker":
reranker = Rerankers().get_reranker()
return reranker.get_rerank_lists(query, results, topn)
+ raise ValueError(f"Unimplemented rerank method '{self.method}'.")
def _rerank_with_vertex_degree(
self,
@@ -125,7 +126,7 @@ class MergeDedupRerank:
reranker.get_rerank_lists(query, vertex_degree) + [""] for vertex_degree in vertex_degree_list
]
except requests.exceptions.RequestException as e:
- log.warning(f"Online reranker fails, automatically switches to
local bleu method: {e}")
+ log.warning("Online reranker fails, automatically switches to
local bleu method: %s", e)
self.method = "bleu"
self.switch_to_bleu = True
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py
index cf252e2..c28d1b4 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py
@@ -37,48 +37,81 @@ class CommitToKg:
self.schema = self.client.schema()
def run(self, data: dict) -> Dict[str, Any]:
- if "schema" not in data:
- self.schema_free_mode(data["triples"])
+ schema = data.get("schema")
+ vertices = data.get("vertices", [])
+ edges = data.get("edges", [])
+
+ if not vertices and not edges:
+ log.critical("(Loading) Both vertices and edges are empty. Please
check the input data again.")
+ raise ValueError("Both vertices and edges input are empty.")
+
+ if not schema:
+ # TODO: ensure the function works correctly (update the logic
later)
+ self.schema_free_mode(data.get("triples", []))
+ log.warning("Using schema_free mode, could try schema_define mode
for better effect!")
else:
- schema = data["schema"]
- vertices = data["vertices"]
- edges = data["edges"]
- self.init_schema(schema)
- self.init_graph(vertices, edges, schema)
+ self.init_schema_if_need(schema)
+ self.load_into_graph(vertices, edges, schema)
return data
- def init_graph(self, vertices, edges, schema):
- key_map = {}
- for vertex in schema["vertexlabels"]:
- key_map[vertex["name"]] = vertex
+ def load_into_graph(self, vertices, edges, schema):
+ vertex_label_map = {v_label["name"]: v_label for v_label in schema["vertexlabels"]}
+ edge_label_map = {e_label["name"]: e_label for e_label in schema["edgelabels"]}
+
for vertex in vertices:
- label = vertex["label"]
- properties = vertex["properties"]
- for pk in key_map[label]["primary_keys"]:
- if pk not in properties:
- properties[pk] = "NULL"
- for uk in key_map[label]["nullable_keys"]:
- if uk not in properties:
- properties[uk] = "NULL"
+ input_label = vertex["label"]
+ # 1. ensure the input_label in the graph schema
+ if input_label not in vertex_label_map:
+ log.critical("(Input) VertexLabel %s not found in schema, skip
& need check it!", input_label)
+ continue
+
+ input_properties = vertex["properties"]
+ vertex_label = vertex_label_map[input_label]
+ primary_keys = vertex_label["primary_keys"]
+ nullable_keys = vertex_label.get("nullable_keys", [])
+ non_null_keys = [key for key in vertex_label["properties"] if key
not in nullable_keys]
+
+ # 2. Handle primary-keys mode vertex
+ for pk in primary_keys:
+ if not input_properties.get(pk):
+ if len(primary_keys) == 1:
+ log.error("Primary-key '%s' missing in vertex %s, skip
it & need check it again", pk, vertex)
+ continue
+ input_properties[pk] = "null" # FIXME: handle
bool/number/date type
+ log.warning("Primary-key '%s' missing in vertex %s, mark
empty & need check it again!", pk, vertex)
+
+ # 3. Ensure all non-nullable props are set
+ for key in non_null_keys:
+ if key not in input_properties:
+ input_properties[key] = "" # FIXME: handle
bool/number/date type
+ log.warning("Property '%s' missing in vertex %s, set to ''
for now", key, vertex)
try:
- vid = self.client.graph().addVertex(label, properties).id
+ # TODO: we could try batch add vertices first, setback to single-mode if failed
+ vid = self.client.graph().addVertex(input_label, input_properties).id
vertex["id"] = vid
except NotFoundError as e:
- print(e)
+ log.error(e)
+ except CreateError as e:
+ log.error("Error on creating vertex: %s, %s", vertex, e)
+
for edge in edges:
start = edge["outV"]
end = edge["inV"]
label = edge["label"]
properties = edge["properties"]
+
+ if label not in edge_label_map:
+ log.critical("(Input) EdgeLabel %s not found in schema, skip &
need check it!", label)
+ continue
try:
+ # TODO: we could try batch add edges first, setback to
single-mode if failed
self.client.graph().addEdge(label, start, end, properties)
except NotFoundError as e:
- print(e)
+ log.error(e)
except CreateError as e:
- log.error("Error on creating edge: %s", str(edge))
- print(e)
+ log.error("Error on creating edge: %s, %s", edge, e)
- def init_schema(self, schema):
+ def init_schema_if_need(self, schema: object):
vertices = schema["vertexlabels"]
edges = schema["edgelabels"]
@@ -92,6 +125,7 @@ class CommitToKg:
self.schema.vertexLabel(vertex_label).properties(*properties).nullableKeys(
*nullable_keys
).usePrimaryKeyId().primaryKeys(*primary_keys).ifNotExist().create()
+
for edge in edges:
edge_label = edge["name"]
source_vertex_label = edge["source_label"]
@@ -105,19 +139,13 @@ class CommitToKg:
def schema_free_mode(self, data):
self.schema.propertyKey("name").asText().ifNotExist().create()
- self.schema.vertexLabel("vertex").useCustomizeStringId().properties(
- "name"
- ).ifNotExist().create()
+ self.schema.vertexLabel("vertex").useCustomizeStringId().properties("name").ifNotExist().create()
self.schema.edgeLabel("edge").sourceLabel("vertex").targetLabel("vertex").properties(
"name"
).ifNotExist().create()
- self.schema.indexLabel("vertexByName").onV("vertex").by(
- "name"
- ).secondary().ifNotExist().create()
- self.schema.indexLabel("edgeByName").onE("edge").by(
- "name"
- ).secondary().ifNotExist().create()
+ self.schema.indexLabel("vertexByName").onV("vertex").by("name").secondary().ifNotExist().create()
+ self.schema.indexLabel("edgeByName").onE("edge").by("name").secondary().ifNotExist().create()
for item in data:
s, p, o = (element.strip() for element in item)
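A hedged usage sketch of the reworked loader (constructor details sit outside this hunk, so the no-argument call is an assumption):

    op = CommitToKg()
    context = op.run({
        "schema": schema_dict,      # present -> schema-defined loading
        "vertices": vertices_list,  # [{"label": ..., "properties": {...}}, ...]
        "edges": edges_list,        # [{"label", "outV", "inV", "properties"}, ...]
    })
    # Without "schema", run() falls back to schema_free_mode(data.get("triples", [])).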
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py
index c3dfc25..77233f0 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py
@@ -26,6 +26,8 @@ class FetchGraphData:
self.graph = graph
def run(self, context: Optional[Dict[str, Any]]) -> Dict[str, Any]:
+ if context is None:
+ context = {}
if "vertices" not in context:
context["vertices"] = []
vertices = self.graph.gremlin().exec("g.V()")["data"]
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index fe225c2..2f08f11 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -119,13 +119,16 @@ class GraphRAGQuery:
edge_labels_str = ",".join("'" + label + "'" for label in edge_labels)
use_id_to_match = self._prop_to_match is None
+ if use_id_to_match:
+ if not entrance_vids:
+ return context
- if not use_id_to_match:
- assert keywords is not None, "No keywords for graph query."
- keywords_str = ",".join("'" + kw + "'" for kw in keywords)
- rag_gremlin_query = self.PROP_RAG_GREMLIN_QUERY_TEMPL.format(
- prop=self._prop_to_match,
- keywords=keywords_str,
+ rag_gremlin_query = self.VERTEX_GREMLIN_QUERY_TEMPL.format(keywords=entrance_vids)
+ result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
+
+ vertex_knowledge = self._format_knowledge_from_vertex(query_result=result)
+ rag_gremlin_query = self.ID_RAG_GREMLIN_QUERY_TEMPL.format(
+ keywords=entrance_vids,
max_deep=self._max_deep,
max_items=self._max_items,
edge_labels=edge_labels_str,
@@ -134,15 +137,17 @@ class GraphRAGQuery:
graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_knowledge_from_query_result(
query_result=result
)
+ graph_chain_knowledge.update(vertex_knowledge)
+ if vertex_degree_list:
+ vertex_degree_list[0].update(vertex_knowledge)
+ else:
+ vertex_degree_list.append(vertex_knowledge)
else:
- assert entrance_vids is not None, "No entrance vertices for query."
- rag_gremlin_query = self.VERTEX_GREMLIN_QUERY_TEMPL.format(
- keywords=entrance_vids,
- )
- result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
- vertex_knowledge = self._format_knowledge_from_vertex(query_result=result)
- rag_gremlin_query = self.ID_RAG_GREMLIN_QUERY_TEMPL.format(
- keywords=entrance_vids,
+ assert keywords, "No related property(keywords) for graph query."
+ keywords_str = ",".join("'" + kw + "'" for kw in keywords)
+ rag_gremlin_query = self.PROP_RAG_GREMLIN_QUERY_TEMPL.format(
+ prop=self._prop_to_match,
+ keywords=keywords_str,
max_deep=self._max_deep,
max_items=self._max_items,
edge_labels=edge_labels_str,
@@ -151,8 +156,6 @@ class GraphRAGQuery:
graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_knowledge_from_query_result(
query_result=result
)
- graph_chain_knowledge.update(vertex_knowledge)
- vertex_degree_list[0].update(vertex_knowledge)
context["graph_result"] = list(graph_chain_knowledge)
context["vertex_degree_list"] = [list(vertex_degree) for vertex_degree
in vertex_degree_list]
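Reordered, the query flow in this hunk reads roughly as follows (a restatement of the diff, not new behavior):

    if self._prop_to_match is None:      # ID-match branch now comes first
        if not entrance_vids:
            return context               # nothing to anchor the traversal on
        # fetch entrance vertices (VERTEX_GREMLIN_QUERY_TEMPL), then expand
        # the subgraph via ID_RAG_GREMLIN_QUERY_TEMPL and merge both results
    else:                                # property-match branch
        assert keywords, "No related property(keywords) for graph query."
        # expand via PROP_RAG_GREMLIN_QUERY_TEMPL using the keyword list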
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py
index 5c002ae..a61063f 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+from typing import Dict, Any, Optional
from hugegraph_llm.config import settings
from pyhugegraph.client import PyHugeClient
@@ -33,7 +33,11 @@ class SchemaManager:
)
self.schema = self.client.schema()
- def run(self):
+ # FIXME: This method is not working as expected
+ # def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
+ def run(self, context: Optional[Dict[str, Any]]) -> Dict[str, Any]:
+ if context is None:
+ context = {}
schema = self.schema.getSchema()
vertices = []
for vl in schema["vertexlabels"]:
@@ -50,4 +54,6 @@ class SchemaManager:
edges.append(edge)
if not vertices and not edges:
raise Exception(f"Can not get {self.graph_name}'s schema from
HugeGraph!")
- return {"vertices": vertices, "edges": edges}
+
+ context.update({"schema": schema})
+ return context
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py
index b026a4c..4e15274 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py
@@ -26,11 +26,7 @@ from hugegraph_llm.indices.vector_index import VectorIndex
class BuildGremlinExampleIndex:
def __init__(self, embedding: BaseEmbedding, examples: List[Dict[str, str]]):
- example_dir_name = ".gremlin_examples"
- self.index_file = os.path.join(resource_path, example_dir_name, "index.faiss")
- self.content_file = os.path.join(resource_path, example_dir_name, "properties.pkl")
- if not os.path.exists(os.path.join(resource_path, example_dir_name)):
+ self.index_dir = os.path.join(resource_path, "gremlin_examples")
self.examples = examples
self.embedding = embedding
@@ -42,6 +38,6 @@ class BuildGremlinExampleIndex:
if len(self.examples) > 0:
vector_index = VectorIndex(embed_dim)
vector_index.add(examples_embedding, self.examples)
- vector_index.to_index_file(str(self.index_file), str(self.content_file))
+ vector_index.to_index_file(self.index_dir)
context["embed_dim"] = embed_dim
return context
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
index 2582653..135ce9c 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
@@ -28,27 +28,26 @@ from hugegraph_llm.utils.log import log
class BuildSemanticIndex:
def __init__(self, embedding: BaseEmbedding):
- self.content_file = str(os.path.join(resource_path, settings.graph_name, "vid.pkl"))
- self.index_file = str(os.path.join(resource_path, settings.graph_name, "vid.faiss"))
+ self.index_dir = str(os.path.join(resource_path, settings.graph_name, "graph_vids"))
+ self.vid_index = VectorIndex.from_index_file(self.index_dir)
self.embedding = embedding
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
- if len(context["vertices"]) > 0:
- log.debug("Building vector index for %s vertices...",
len(context["vertices"]))
- vids = []
- vids_embedding = []
- for vertex in tqdm(context["vertices"]):
- vertex_text = f"{vertex['label']}\n{vertex['properties']}"
-
vids_embedding.append(self.embedding.get_text_embedding(vertex_text))
- vids.append(vertex["id"])
- vids_embedding = [self.embedding.get_text_embedding(vid) for vid
in vids]
- log.debug("Vector index built for %s vertices.", len(vids))
- if os.path.exists(self.index_file) and
os.path.exists(self.content_file):
- vector_index = VectorIndex.from_index_file(self.index_file,
self.content_file)
- else:
- vector_index = VectorIndex(len(vids_embedding[0]))
- vector_index.add(vids_embedding, vids)
- vector_index.to_index_file(self.index_file, self.content_file)
+ past_vids = self.vid_index.properties
+ present_vids = [v["id"] for v in context["vertices"]]
+ removed_vids = set(past_vids) - set(present_vids)
+ removed_num = self.vid_index.remove(removed_vids)
+ added_vids = list(set(present_vids) - set(past_vids))
+ if len(added_vids) > 0:
+ log.debug("Building vector index for %s vertices...",
len(added_vids))
+ added_embeddings = [self.embedding.get_text_embedding(v) for v in
tqdm(added_vids)]
+ log.debug("Vector index built for %s vertices.",
len(added_embeddings))
+ self.vid_index.add(added_embeddings, added_vids)
+ self.vid_index.to_index_file(self.index_dir)
else:
log.debug("No vertices to build vector index.")
+ context.update({
+ "removed_vid_vector_num": removed_num,
+ "added_vid_vector_num": len(added_vids)
+ })
return context
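The incremental update above boils down to two set differences; a condensed sketch:

    past_vids = set(vid_index.properties)                  # already embedded
    present_vids = {v["id"] for v in context["vertices"]}

    removed = vid_index.remove(past_vids - present_vids)   # drop stale vids
    added = sorted(present_vids - past_vids)               # embed only new vids
    vid_index.add([embedding.get_text_embedding(v) for v in added], added)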
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py
index 1fccf15..f11969e 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py
@@ -29,10 +29,8 @@ from hugegraph_llm.utils.log import log
class BuildVectorIndex:
def __init__(self, embedding: BaseEmbedding):
self.embedding = embedding
- self.index_file = str(os.path.join(resource_path, settings.graph_name, "vidx.faiss"))
- self.content_file = str(os.path.join(resource_path, settings.graph_name, "vidx.pkl"))
- if not os.path.exists(os.path.join(resource_path, settings.graph_name)):
- os.mkdir(os.path.join(resource_path, settings.graph_name))
+ self.index_dir = str(os.path.join(resource_path, settings.graph_name, "chunks"))
+ self.vector_index = VectorIndex.from_index_file(self.index_dir)
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
if "chunks" not in context:
@@ -43,10 +41,6 @@ class BuildVectorIndex:
for chunk in tqdm(chunks):
chunks_embedding.append(self.embedding.get_text_embedding(chunk))
if len(chunks_embedding) > 0:
- if os.path.exists(self.index_file) and
os.path.exists(self.content_file):
- vector_index = VectorIndex.from_index_file(self.index_file,
self.content_file)
- else:
- vector_index = VectorIndex(len(chunks_embedding[0]))
- vector_index.add(chunks_embedding, chunks)
- vector_index.to_index_file(self.index_file, self.content_file)
+ self.vector_index.add(chunks_embedding, chunks)
+ self.vector_index.to_index_file(self.index_dir)
return context
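Both builders now hand VectorIndex a per-purpose directory ("chunks" here, "graph_vids" above), and from_index_file falls back to an empty index when nothing exists on disk yet, which is why the exists-check branching could be dropped. A hedged sketch of that load-or-create pattern; the file names inside the directory are assumptions, not the project's actual layout:

    import os
    import pickle

    import faiss  # the project persists FAISS indices, so faiss is assumed here

    def load_or_create(index_dir: str, dim: int):
        """Load a FAISS index plus its properties from a directory, else start empty."""
        index_file = os.path.join(index_dir, "index.faiss")       # assumed name
        content_file = os.path.join(index_dir, "properties.pkl")  # assumed name
        if os.path.exists(index_file) and os.path.exists(content_file):
            with open(content_file, "rb") as f:
                return faiss.read_index(index_file), pickle.load(f)
        os.makedirs(index_dir, exist_ok=True)
        return faiss.IndexFlatL2(dim), []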
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
index e38aa86..ddcf589 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
@@ -18,6 +18,7 @@
 import os
 from typing import Dict, Any
+
 from hugegraph_llm.config import resource_path
 from hugegraph_llm.models.embeddings.base import BaseEmbedding
 from hugegraph_llm.indices.vector_index import VectorIndex
@@ -26,12 +27,10 @@ from hugegraph_llm.indices.vector_index import VectorIndex
 class GremlinExampleIndexQuery:
     def __init__(self, query: str, embedding: BaseEmbedding, num_examples: int = 1):
         self.query = query
-        example_dir_name = ".gremlin_examples"
         self.embedding = embedding
         self.num_examples = num_examples
-        index_file = str(os.path.join(resource_path, example_dir_name, "index.faiss"))
-        content_file = str(os.path.join(resource_path, example_dir_name, "properties.pkl"))
-        self.vector_index = VectorIndex.from_index_file(index_file, content_file)
+        self.index_dir = os.path.join(resource_path, "gremlin_examples")
+        self.vector_index = VectorIndex.from_index_file(self.index_dir)
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         context["query"] = self.query
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
index 042d140..6cd8620 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
@@ -17,14 +17,17 @@
 import os
-from typing import Dict, Any, Literal
+from copy import deepcopy
+from typing import Dict, Any, Literal, List, Tuple
+from pyhugegraph.client import PyHugeClient
 from hugegraph_llm.config import resource_path, settings
 from hugegraph_llm.indices.vector_index import VectorIndex
 from hugegraph_llm.models.embeddings.base import BaseEmbedding
 class SemanticIdQuery:
+    ID_QUERY_TEMPL = "g.V({vids_str})"
     def __init__(
         self,
         embedding: BaseEmbedding,
@@ -32,13 +35,20 @@ class SemanticIdQuery:
         topk_per_query: int = 10,
         topk_per_keyword: int = 1
     ):
-        index_file = str(os.path.join(resource_path, settings.graph_name, "vid.faiss"))
-        content_file = str(os.path.join(resource_path, settings.graph_name, "vid.pkl"))
-        self.vector_index = VectorIndex.from_index_file(index_file, content_file)
+        self.index_dir = str(os.path.join(resource_path, settings.graph_name, "graph_vids"))
+        self.vector_index = VectorIndex.from_index_file(self.index_dir)
         self.embedding = embedding
         self.by = by
         self.topk_per_query = topk_per_query
         self.topk_per_keyword = topk_per_keyword
+        self._client = PyHugeClient(
+            settings.graph_ip,
+            settings.graph_port,
+            settings.graph_name,
+            settings.graph_user,
+            settings.graph_pwd,
+            settings.graph_space,
+        )
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         graph_query_entrance = []
@@ -49,11 +59,34 @@ class SemanticIdQuery:
             if results:
                 graph_query_entrance.extend(results[:self.topk_per_query])
         else:  # by keywords
-            keywords = context["keywords"]
-            for keyword in keywords:
-                keyword_vector = self.embedding.get_text_embedding(keyword)
-                results = self.vector_index.search(keyword_vector, top_k=self.topk_per_keyword)
-                if results:
-                    graph_query_entrance.extend(results[:self.topk_per_keyword])
+            exact_match_vids, unmatched_vids = self._exact_match_vids(context["keywords"])
+            graph_query_entrance.extend(exact_match_vids)
+            fuzzy_match_vids = self._fuzzy_match_vids(unmatched_vids)
+            graph_query_entrance.extend(fuzzy_match_vids)
         context["entrance_vids"] = list(set(graph_query_entrance))
         return context
+
+    def _exact_match_vids(self, keywords: List[str]) -> Tuple[List[str], List[str]]:
+        vertex_label_num = len(self._client.schema().getVertexLabels())
+        possible_vids = deepcopy(keywords)
+        for i in range(vertex_label_num):
+            possible_vids.extend([f"{i+1}:{keyword}" for keyword in keywords])
+        vids_str = ",".join([f"'{vid}'" for vid in possible_vids])
+        resp = self._client.gremlin().exec(SemanticIdQuery.ID_QUERY_TEMPL.format(vids_str=vids_str))
+        searched_vids = [v['id'] for v in resp['data']]
+        unsearched_keywords = set(keywords)
+        for vid in searched_vids:
+            for keyword in unsearched_keywords:
+                if keyword in vid:
+                    unsearched_keywords.remove(keyword)
+                    break
+        return searched_vids, list(unsearched_keywords)
+
+    def _fuzzy_match_vids(self, keywords: List[str]) -> List[str]:
+        fuzzy_match_result = []
+        for keyword in keywords:
+            keyword_vector = self.embedding.get_text_embedding(keyword)
+            results = self.vector_index.search(keyword_vector, top_k=self.topk_per_keyword)
+            if results:
+                fuzzy_match_result.extend(results[:self.topk_per_keyword])
+        return fuzzy_match_result
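The new lookup order in SemanticIdQuery is exact-first: every keyword is expanded with each possible label-id prefix (HugeGraph primary-key vids serialize like "1:name"), all candidates are probed in a single g.V(...) call, and only the unmatched keywords fall back to embedding search. A small illustration of the candidate expansion (the label count is assumed; the real code reads it from the schema):

    keywords = ["Alice", "HugeGraph"]
    vertex_label_num = 3  # assumed; _exact_match_vids asks the schema

    possible_vids = list(keywords)
    for i in range(vertex_label_num):
        possible_vids.extend(f"{i + 1}:{kw}" for kw in keywords)

    vids_str = ",".join(f"'{vid}'" for vid in possible_vids)
    print(f"g.V({vids_str})")  # one round-trip instead of a query per keyword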
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py
index 16ed603..e845b61 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py
@@ -18,6 +18,7 @@
 import os
 from typing import Dict, Any
+
 from hugegraph_llm.config import resource_path, settings
 from hugegraph_llm.models.embeddings.base import BaseEmbedding
 from hugegraph_llm.indices.vector_index import VectorIndex
@@ -27,9 +28,8 @@ class VectorIndexQuery:
     def __init__(self, embedding: BaseEmbedding, topk: int = 3):
         self.embedding = embedding
         self.topk = topk
-        index_file = str(os.path.join(resource_path, settings.graph_name, "vidx.faiss"))
-        content_file = str(os.path.join(resource_path, settings.graph_name, "vidx.pkl"))
-        self.vector_index = VectorIndex.from_index_file(index_file, content_file)
+        self.index_dir = str(os.path.join(resource_path, settings.graph_name, "chunks"))
+        self.vector_index = VectorIndex.from_index_file(self.index_dir)
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         query = context.get("query")
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py b/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py
index 9b32977..921895f 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py
@@ -52,7 +52,7 @@ class KgBuilder:
         elif from_extraction:
             raise Exception("Not implemented yet")
         else:
-            raise Exception("No input data")
+            raise Exception("No input data / invalid schema type")
         return self
     def fetch_graph_data(self):
@@ -98,8 +98,7 @@ class KgBuilder:
     @log_time("total time")
     @record_qps
-    def run(self) -> Dict[str, Any]:
-        context = None
+    def run(self, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        for operator in self.operators:
             context = self._run_operator(operator, context)
         return context
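Accepting an initial context in run() is what enables pipelines that start from user-supplied data rather than from an extraction step (see import_graph_data further down). A toy sketch of the same seeding idea, with a stand-in pipeline class rather than KgBuilder itself:

    from typing import Any, Dict, Optional

    class Pipeline:
        """Stand-in for KgBuilder: operators are callables over a context dict."""
        def __init__(self):
            self.operators = []

        def then(self, op):
            self.operators.append(op)
            return self

        def run(self, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
            for op in self.operators:  # context may now arrive pre-seeded
                context = op(context or {})
            return context

    pipe = Pipeline().then(lambda ctx: {**ctx, "committed": len(ctx["vertices"])})
    print(pipe.run({"vertices": [{"id": "1:Alice"}], "edges": []}))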
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index 129b77b..fda6eb7 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -24,7 +24,7 @@ from hugegraph_llm.models.llms.init_llm import LLMs
 from hugegraph_llm.config import prompt
-DEFAULT_ANSWER_TEMPLATE = prompt.default_answer_template
+DEFAULT_ANSWER_TEMPLATE = prompt.answer_prompt
 class AnswerSynthesize:
@@ -77,25 +77,26 @@ class AnswerSynthesize:
             response = self._llm.generate(prompt=prompt)
             return {"answer": response}
-        vector_result = context.get("vector_result", [])
-        if len(vector_result) == 0:
-            vector_result_context = "No (vector)phrase related to the query."
-        else:
+        vector_result = context.get("vector_result")
+        if vector_result:
             vector_result_context = "Phrases related to the query:\n" + "\n".join(
                 f"{i + 1}. {res}" for i, res in enumerate(vector_result)
             )
-        graph_result = context.get("graph_result", [])
-        if len(graph_result) == 0:
-            graph_result_context = "No knowledge found in HugeGraph for the query."
         else:
+            vector_result_context = "No (vector)phrase related to the query."
+
+        graph_result = context.get("graph_result")
+        if graph_result:
             graph_context_head = context.get("graph_context_head",
                                              "The following are knowledge from HugeGraph related to the query:\n")
             graph_result_context = graph_context_head + "\n".join(
                 f"{i + 1}. {res}" for i, res in enumerate(graph_result)
             )
+        else:
+            graph_result_context = "No related knowledge found in graph for the query."
+
         context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str,
                                                   vector_result_context, graph_result_context))
-
         return context
     async def async_generate(self, context: Dict[str, Any], context_head_str: str,
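The reworked branches lean on context.get(key) with no default: a single truthiness test now covers a missing key, an explicit None, and an empty list, where the old len(...) == 0 check would raise on None. For instance:

    context = {"graph_result": []}                 # present but empty
    vector_result = context.get("vector_result")   # absent -> None
    graph_result = context.get("graph_result")     # empty -> falsy
    assert not vector_result and not graph_result  # both take the fallback text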
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
index e4854c1..3249670 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
@@ -121,7 +121,10 @@ class KeywordExtract:
             for k in re.split(r"[,，]+", match):
                 k = k.strip()
                 if len(k) > 0:
-                    keywords.append(k.lower())
+                    if lowercase:
+                        keywords.append(k.lower())
+                    else:
+                        keywords.append(k)
         # if the keyword consists of multiple words, split into sub-words
         # (removing stopwords)
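For context, the split pattern above handles both ASCII and full-width commas, and lowercasing becomes opt-in through the surrounding method's lowercase flag instead of being unconditional. A self-contained rendering of that snippet:

    import re

    def split_keywords(match: str, lowercase: bool) -> list:
        keywords = []
        for k in re.split(r"[,，]+", match):  # ASCII and full-width commas
            k = k.strip()
            if k:
                keywords.append(k.lower() if lowercase else k)
        return keywords

    print(split_keywords("HugeGraph，Graph RAG, LLM", lowercase=False))
    # -> ['HugeGraph', 'Graph RAG', 'LLM']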
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py
index d080d56..33be1cd 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py
@@ -25,7 +25,7 @@ from hugegraph_llm.document.chunk_split import ChunkSplitter
 from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.utils.log import log
-SCHEMA_EXAMPLE_PROMPT = prompt.schema_example_prompt
+SCHEMA_EXAMPLE_PROMPT = prompt.extract_graph_prompt
 def generate_extract_property_graph_prompt(text, schema=None) -> str:
     return f"""---
@@ -45,6 +45,39 @@ def split_text(text: str) -> List[str]:
     return chunks
+def filter_item(schema, items) -> List[Dict[str, Any]]:
+    # filter vertex and edge with invalid properties
+    filtered_items = []
+    properties_map = {"vertex": {}, "edge": {}}
+    for vertex in schema["vertexlabels"]:
+        properties_map["vertex"][vertex["name"]] = {
+            "primary_keys": vertex["primary_keys"],
+            "nullable_keys": vertex["nullable_keys"],
+            "properties": vertex["properties"]
+        }
+    for edge in schema["edgelabels"]:
+        properties_map["edge"][edge["name"]] = {
+            "properties": edge["properties"]
+        }
+    log.info("properties_map: %s", properties_map)
+    for item in items:
+        item_type = item["type"]
+        if item_type == "vertex":
+            label = item["label"]
+            non_nullable_keys = (
+                set(properties_map[item_type][label]["properties"])
+                .difference(set(properties_map[item_type][label]["nullable_keys"])))
+            for key in non_nullable_keys:
+                if key not in item["properties"]:
+                    item["properties"][key] = "NULL"
+        for key, value in item["properties"].items():
+            if not isinstance(value, str):
+                item["properties"][key] = str(value)
+        filtered_items.append(item)
+
+    return filtered_items
+
+
 class PropertyGraphExtract:
     def __init__(
         self,
@@ -67,7 +100,7 @@ class PropertyGraphExtract:
             proceeded_chunk = self.extract_property_graph_by_llm(schema, chunk)
             log.debug("[LLM] %s input: %s \n output:%s", self.__class__.__name__, chunk, proceeded_chunk)
             items.extend(self._extract_and_filter_label(schema, proceeded_chunk))
-        items = self.filter_item(schema, items)
+        items = filter_item(schema, items)
         for item in items:
             if item["type"] == "vertex":
                 context["vertices"].append(item)
@@ -97,52 +130,20 @@ class PropertyGraphExtract:
         edge_label_set = {edge["name"] for edge in schema["edgelabels"]}
         for item in property_graph:
             if not isinstance(item, dict):
-                log.warning("Invalid property graph item type %s.", type(item))
+                log.warning("Invalid property graph item type '%s'.", type(item))
                 continue
             if not self.NECESSARY_ITEM_KEYS.issubset(item.keys()):
-                log.warning("Invalid item keys %s.", item.keys())
+                log.warning("Invalid item keys '%s'.", item.keys())
                 continue
             if item["type"] == "vertex" or item["type"] == "edge":
                 if (item["label"] not in vertex_label_set
                         and item["label"] not in edge_label_set):
-                    log.warning("Invalid item label %s has been ignored.", item["label"])
+                    log.warning("Invalid '%s' label '%s' has been ignored.", item["type"], item["label"])
                 else:
                     items.append(item)
             else:
-                log.warning("Invalid item type %s has been ignored.", item["type"])
+                log.warning("Invalid item type '%s' has been ignored.", item["type"])
         except json.JSONDecodeError:
             log.critical("Invalid property graph! Please check the extracted JSON data carefully")
         return items
-
-    def filter_item(self, schema, items) -> List[Dict[str, Any]]:
-        # filter vertex and edge with invalid properties
-        filtered_items = []
-        properties_map = {"vertex": {}, "edge": {}}
-        for vertex in schema["vertexlabels"]:
-            properties_map["vertex"][vertex["name"]] = {
-                "primary_keys": vertex["primary_keys"],
-                "nullable_keys": vertex["nullable_keys"],
-                "properties": vertex["properties"]
-            }
-        for edge in schema["edgelabels"]:
-            properties_map["edge"][edge["name"]] = {
-                "properties": edge["properties"]
-            }
-        log.info("properties_map: %s", properties_map)
-        for item in items:
-            item_type = item["type"]
-            if item_type == "vertex":
-                label = item["label"]
-                non_nullable_keys = (
-                    set(properties_map[item_type][label]["properties"])
-                    .difference(set(properties_map[item_type][label]["nullable_keys"])))
-                for key in non_nullable_keys:
-                    if key not in item["properties"]:
-                        item["properties"][key] = "NULL"
-            for key, value in item["properties"].items():
-                if not isinstance(value, str):
-                    item["properties"][key] = str(value)
-            filtered_items.append(item)
-
-        return filtered_items
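Since filter_item no longer needs self, it can be exercised in isolation. A quick sketch with a toy schema showing the two normalizations it performs: missing non-nullable vertex properties are padded with "NULL" and non-string values are stringified:

    from hugegraph_llm.operators.llm_op.property_graph_extract import filter_item

    schema = {
        "vertexlabels": [{
            "name": "person",
            "primary_keys": ["name"],
            "nullable_keys": ["age"],
            "properties": ["name", "age", "city"],
        }],
        "edgelabels": [],
    }
    items = [{"type": "vertex", "label": "person",
              "properties": {"name": "Alice", "age": 30}}]

    print(filter_item(schema, items))
    # -> [{'type': 'vertex', 'label': 'person',
    #      'properties': {'name': 'Alice', 'age': '30', 'city': 'NULL'}}]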
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py b/hugegraph-llm/src/hugegraph_llm/resources/demo/css.py
similarity index 55%
copy from hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py
copy to hugegraph-llm/src/hugegraph_llm/resources/demo/css.py
index c3dfc25..0c56a40 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py
+++ b/hugegraph-llm/src/hugegraph_llm/resources/demo/css.py
@@ -15,24 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
+CSS = """
+footer {
+    visibility: hidden
+}
-from typing import Optional, Dict, Any
+.code-container-edit {
+    max-height: 520px;
+    overflow-y: auto;  /* enable scroll */
+}
-from pyhugegraph.client import PyHugeClient
-
-
-class FetchGraphData:
-    def __init__(self, graph: PyHugeClient):
-        self.graph = graph
-
-    def run(self, context: Optional[Dict[str, Any]]) -> Dict[str, Any]:
-        if "vertices" not in context:
-            context["vertices"] = []
-        vertices = self.graph.gremlin().exec("g.V()")["data"]
-        for vertex in vertices:
-            context["vertices"].append({
-                "id": vertex["id"],
-                "label": vertex["label"],
-                "properties": vertex["properties"]
-            })
-        return context
+.code-container-show {
+    max-height: 250px;
+    overflow-y: auto;  /* enable scroll */
+}
+"""
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/decorators.py b/hugegraph-llm/src/hugegraph_llm/utils/decorators.py
index d9b1df7..5f37d47 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/decorators.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/decorators.py
@@ -53,8 +53,7 @@ def log_time(msg: Optional[str] = "") -> Callable:
         if asyncio.iscoroutinefunction(func):
             return async_wrapper
-        else:
-            return sync_wrapper
+        return sync_wrapper
     # handle "@log_time" usage -> better to use "@log_time()" instead
     if callable(msg):
@@ -73,7 +72,7 @@ def log_operator_time(func: Callable) -> Callable:
         # Only record time ≥ 0.01s (10ms)
         if op_time >= 0.01:
             log.debug("Operator %s finished in %.2f seconds", operator.__class__.__name__, op_time)
-            log.debug("Context:\n%s", result)
+            # log.debug("Context:\n%s", result)
         return result
     return wrapper
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
new file mode 100644
index 0000000..73bb813
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import json
+import os
+import traceback
+from typing import Dict, Any, Union
+
+import gradio as gr
+from .hugegraph_utils import get_hg_client, clean_hg_data
+from .log import log
+from .vector_index_utils import read_documents
+from ..config import resource_path, settings, prompt
+from ..indices.vector_index import VectorIndex
+from ..models.embeddings.init_embedding import Embeddings
+from ..models.llms.init_llm import LLMs
+from ..operators.kg_construction_task import KgBuilder
+
+
+def get_graph_index_info():
+    builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
+    context = builder.fetch_graph_data().run()
+    vector_index = VectorIndex.from_index_file(str(os.path.join(resource_path, settings.graph_name, "graph_vids")))
+    context["vid_index"] = {
+        "embed_dim": vector_index.index.d,
+        "num_vectors": vector_index.index.ntotal,
+        "num_vids": len(vector_index.properties)
+    }
+    return json.dumps(context, ensure_ascii=False, indent=2)
+
+
+def clean_all_graph_index():
+    clean_hg_data()
+    VectorIndex.clean(str(os.path.join(resource_path, settings.graph_name, "graph_vids")))
+    gr.Info("Clean graph index successfully!")
+
+
+def extract_graph(input_file, input_text, schema, example_prompt) -> str:
+    # update env variables: schema and example_prompt
+    if prompt.graph_schema != schema or prompt.extract_graph_prompt != example_prompt:
+        prompt.graph_schema = schema
+        prompt.extract_graph_prompt = example_prompt
+        prompt.update_yaml_file()
+
+    texts = read_documents(input_file, input_text)
+    builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
+
+    if schema:
+        try:
+            schema = json.loads(schema.strip())
+            builder.import_schema(from_user_defined=schema)
+        except json.JSONDecodeError:
+            log.info("Get schema from graph!")
+            builder.import_schema(from_hugegraph=schema)
+    else:
+        return "ERROR: please input with correct schema/format."
+    builder.chunk_split(texts, "paragraph", "zh").extract_info(example_prompt, "property_graph")
+
+    try:
+        context = builder.run()
+        graph_elements = {
+            "vertices": context["vertices"],
+            "edges": context["edges"]
+        }
+        return json.dumps(graph_elements, ensure_ascii=False, indent=2)
+    except Exception as e:  # pylint: disable=broad-exception-caught
+        log.error(e)
+        raise gr.Error(str(e))
+
+
+def fit_vid_index():
+    builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
+    builder.fetch_graph_data().build_vertex_id_semantic_index()
+    log.debug(builder.operators)
+    try:
+        context = builder.run()
+        removed_num = context["removed_vid_vector_num"]
+        added_num = context["added_vid_vector_num"]
+        return f"Removed {removed_num} vectors, added {added_num} vectors."
+    except Exception as e:  # pylint: disable=broad-exception-caught
+        log.error(e)
+        raise gr.Error(str(e))
+
+
+def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]:
+    try:
+        data_json = json.loads(data.strip())
+        log.debug("Import graph data: %s", data)
+        builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
+        if schema:
+            try:
+                schema = json.loads(schema.strip())
+                builder.import_schema(from_user_defined=schema)
+            except json.JSONDecodeError:
+                log.info("Get schema from graph!")
+                builder.import_schema(from_hugegraph=schema)
+
+        context = builder.commit_to_hugegraph().run(data_json)
+        gr.Info("Import graph data successfully!")
+        return json.dumps(context, ensure_ascii=False, indent=2)
+    except Exception as e:
+        log.error(e)
+        traceback.print_exc()
+        # Note: can't use gr.Error here
+        gr.Warning(str(e) + " Please check the graph data format/type carefully.")
+        return data
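These helpers are written as Gradio event handlers (gr.Info/gr.Warning for toasts, gr.Error to abort), so each maps onto one of the new demo buttons. A hedged sketch of how such wiring could look; component names and layout here are illustrative, not rag_web_demo.py's actual code:

    import gradio as gr

    from hugegraph_llm.utils.graph_index_utils import get_graph_index_info, fit_vid_index

    with gr.Blocks() as demo:
        info_btn = gr.Button("Get graph index info")
        fit_btn = gr.Button("Update vid embeddings")
        out = gr.Textbox(label="Result", lines=8)

        info_btn.click(fn=get_graph_index_info, outputs=out)
        fit_btn.click(fn=fit_vid_index, outputs=out)

    demo.launch()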
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py
index 3320efa..f63f806 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py
@@ -16,13 +16,13 @@
 # under the License.
 import json
-from pyhugegraph.client import PyHugeClient
 from hugegraph_llm.config import settings
+from pyhugegraph.client import PyHugeClient
-def run_gremlin_query(query, format=False):
+def run_gremlin_query(query, fmt=True):
     res = get_hg_client().gremlin().exec(query)
-    return json.dumps(res, indent=4, ensure_ascii=False) if format else res
+    return json.dumps(res, indent=4, ensure_ascii=False) if fmt else res
 def get_hg_client():
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py
index ddb6dfc..edd1902 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py
@@ -14,19 +14,62 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-
+import json
 import os
+import docx
+import gradio as gr
 from hugegraph_llm.config import resource_path, settings
+from hugegraph_llm.indices.vector_index import VectorIndex
+from hugegraph_llm.models.embeddings.init_embedding import Embeddings
+from hugegraph_llm.models.llms.init_llm import LLMs
+from hugegraph_llm.operators.kg_construction_task import KgBuilder
+from hugegraph_llm.utils.hugegraph_utils import get_hg_client
+
+
+def read_documents(input_file, input_text):
+    if input_file:
+        texts = []
+        for file in input_file:
+            full_path = file.name
+            if full_path.endswith(".txt"):
+                with open(full_path, "r", encoding="utf-8") as f:
+                    texts.append(f.read())
+            elif full_path.endswith(".docx"):
+                text = ""
+                doc = docx.Document(full_path)
+                for para in doc.paragraphs:
+                    text += para.text
+                    text += "\n"
+                texts.append(text)
+            elif full_path.endswith(".pdf"):
+                # TODO: support PDF file
+                raise gr.Error("PDF will be supported later! Try to upload text/docx now")
+            else:
+                raise gr.Error("Please input txt or docx file.")
+    elif input_text:
+        texts = [input_text]
+    else:
+        raise gr.Error("Please input text or upload file.")
+    return texts
+
+
+def get_vector_index_info():
+    vector_index = VectorIndex.from_index_file(str(os.path.join(resource_path, settings.graph_name, "chunks")))
+    return json.dumps({
+        "embed_dim": vector_index.index.d,
+        "num_vectors": vector_index.index.ntotal,
+        "num_properties": len(vector_index.properties)
+    }, ensure_ascii=False, indent=2)
 def clean_vector_index():
-    if os.path.exists(os.path.join(resource_path, settings.graph_name, "vidx.faiss")):
-        os.remove(os.path.join(resource_path, settings.graph_name, "vidx.faiss"))
-    if os.path.exists(os.path.join(resource_path, settings.graph_name, "vidx.pkl")):
-        os.remove(os.path.join(resource_path, settings.graph_name, "vidx.pkl"))
-    if os.path.exists(os.path.join(resource_path, settings.graph_name, "vid.faiss")):
-        os.remove(os.path.join(resource_path, settings.graph_name, "vid.faiss"))
-    if os.path.exists(os.path.join(resource_path, settings.graph_name, "vid.pkl")):
-        os.remove(os.path.join(resource_path, settings.graph_name, "vid.pkl"))
+    VectorIndex.clean(str(os.path.join(resource_path, settings.graph_name, "chunks")))
+    gr.Info("Clean vector index successfully!")
+
+
+def build_vector_index(input_file, input_text):
+    texts = read_documents(input_file, input_text)
+    builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
+    context = builder.chunk_split(texts, "paragraph", "zh").build_vector_index().run()
+    return json.dumps(context, ensure_ascii=False, indent=2)
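read_documents receives Gradio's upload list, whose items expose the temp path as .name. A small hedged usage sketch with a stand-in upload object (the path is made up for the example):

    from types import SimpleNamespace

    from hugegraph_llm.utils.vector_index_utils import read_documents

    with open("/tmp/sample.txt", "w", encoding="utf-8") as f:
        f.write("HugeGraph is a graph database.")

    upload = SimpleNamespace(name="/tmp/sample.txt")  # mimics a Gradio file object
    print(read_documents([upload], input_text=None))
    # -> ['HugeGraph is a graph database.']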
diff --git a/hugegraph-python-client/setup.py b/hugegraph-python-client/setup.py
index 12161f1..158e734 100644
--- a/hugegraph-python-client/setup.py
+++ b/hugegraph-python-client/setup.py
@@ -26,7 +26,7 @@ with open("requirements.txt", encoding="utf-8") as fp:
 setuptools.setup(
     name="hugegraph-python",
-    version="1.3.0",
+    version="1.5.0",
     author="Apache HugeGraph Contributors",
     author_email="[email protected]",
     install_requires=install_requires,
diff --git a/hugegraph-python-client/src/pyhugegraph/utils/huge_requests.py b/hugegraph-python-client/src/pyhugegraph/utils/huge_requests.py
index 3cfe625..6951064 100644
--- a/hugegraph-python-client/src/pyhugegraph/utils/huge_requests.py
+++ b/hugegraph-python-client/src/pyhugegraph/utils/huge_requests.py
@@ -20,12 +20,13 @@ from typing import Any, Optional
 from urllib.parse import urljoin
 import requests
+from requests.adapters import HTTPAdapter
+from urllib3.util.retry import Retry
+
 from pyhugegraph.utils.constants import Constants
 from pyhugegraph.utils.huge_config import HGraphConfig
 from pyhugegraph.utils.log import log
 from pyhugegraph.utils.util import ResponseValidation
-from requests.adapters import HTTPAdapter
-from urllib3.util.retry import Retry
 class HGraphSession:
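The regrouped HTTPAdapter/Retry imports back the session's retry setup. For reference, the standard way to mount such a retry policy on a requests.Session; the counts and backoff below are placeholders, not the client's actual configuration:

    import requests
    from requests.adapters import HTTPAdapter
    from urllib3.util.retry import Retry

    session = requests.Session()
    retries = Retry(total=3, backoff_factor=0.5,
                    status_forcelist=(500, 502, 503, 504))  # placeholder policy
    adapter = HTTPAdapter(max_retries=retries)
    session.mount("http://", adapter)
    session.mount("https://", adapter)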
diff --git a/hugegraph-python-client/src/pyhugegraph/utils/util.py b/hugegraph-python-client/src/pyhugegraph/utils/util.py
index d90e6ba..f0a8015 100644
--- a/hugegraph-python-client/src/pyhugegraph/utils/util.py
+++ b/hugegraph-python-client/src/pyhugegraph/utils/util.py
@@ -48,7 +48,7 @@ def check_if_authorized(response):
 def check_if_success(response, error=None):
     if (not str(response.status_code).startswith("20")) and check_if_authorized(
-        response
+        response
     ):
         if error is None:
             error = NotFoundError(response.content)
@@ -85,7 +85,6 @@ class ResponseValidation:
         try:
             response.raise_for_status()
-
             if response.status_code == 204:
                 log.debug("No content returned (204) for %s: %s", method, path)
             else:
@@ -100,24 +99,21 @@ class ResponseValidation:
         except requests.exceptions.HTTPError as e:
             if not self._strict and response.status_code == 404:
-                log.info(  # pylint: disable=logging-fstring-interpolation
-                    f"Resource {path} not found (404)"
-                )
+                log.info("Resource %s not found (404)", path)
             else:
                 try:
-                    details = response.json().get(
-                        "exception", "key 'exception' not found"
-                    )
+                    details = response.json().get("exception", "key 'exception' not found")
                 except (ValueError, KeyError):
                     details = "key 'exception' not found"
+                req_body = response.request.body if response.request.body else "Empty body"
+                req_body = req_body.encode('utf-8').decode('unicode_escape')
                 log.error(  # pylint: disable=logging-fstring-interpolation
-                    f"{method}: {e}\n[Body]: {response.request.body}\n[Server Exception]: {details}"
+                    f"{method}: {e}\n[Body]: {req_body}\n[Server Exception]: {details}"
                 )
                 if response.status_code == 404:
                     raise NotFoundError(response.content) from e
-
                 raise e
         except Exception:  # pylint: disable=broad-exception-caught