feat: allow to use customized GraphRAG settings.yaml #387

Merged · 6 commits · Oct 14, 2024
3 changes: 3 additions & 0 deletions .env.example
@@ -25,6 +25,9 @@ GRAPHRAG_API_KEY=<YOUR_OPENAI_KEY>
 GRAPHRAG_LLM_MODEL=gpt-4o-mini
 GRAPHRAG_EMBEDDING_MODEL=text-embedding-3-small

+# set to true if you want to use a customized GraphRAG config file
+USE_CUSTOMIZED_GRAPHRAG_SETTING=false
+
 # settings for Azure DI
 AZURE_DI_ENDPOINT=
 AZURE_DI_CREDENTIAL=
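For reference, pipelines.py (below) reads this flag with python-decouple and compares the lowercased string against "true". An equivalent read (a sketch, not the PR's code) could let decouple cast the value to a boolean directly:

from decouple import config

# decouple parses truthy strings such as "true"/"false", "1"/"0", "yes"/"no"
use_custom_settings = config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default=False, cast=bool)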
33 changes: 31 additions & 2 deletions libs/ktem/ktem/index/file/graph/pipelines.py
@@ -1,4 +1,5 @@
 import os
+import shutil
 import subprocess
 from pathlib import Path
 from shutil import rmtree
@@ -7,6 +8,8 @@

 import pandas as pd
 import tiktoken
+import yaml
+from decouple import config
 from ktem.db.models import engine
 from sqlalchemy.orm import Session
 from theflow.settings import settings
@@ -116,6 +119,16 @@ def call_graphrag_index(self, input_path: str):
         print(result.stdout)
         command = command[:-1]

+        # copy the customized GraphRAG config file into the index input path
+        if config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default="false").lower() == "true":
+            setting_file_path = os.path.join(os.getcwd(), "settings.yaml.example")
+            destination_file_path = os.path.join(input_path, "settings.yaml")
+            try:
+                shutil.copy(setting_file_path, destination_file_path)
+            except (shutil.Error, OSError):
+                # handle the error if the file copy fails
+                print("Failed to copy the customized GraphRAG config file.")
+
         # Run the command and stream stdout
         with subprocess.Popen(command, stdout=subprocess.PIPE, text=True) as process:
             if process.stdout:
@@ -221,12 +234,28 @@ def _build_graph_search(self):
         text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
         text_units = read_indexer_text_units(text_unit_df)

+        # initialize default embedding settings from environment variables
+        embedding_model = os.getenv(
+            "GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small"
+        )
+        embedding_api_key = os.getenv("GRAPHRAG_API_KEY")
+        embedding_api_base = None
+
+        # override the defaults with the customized GraphRAG settings if the flag is set
+        if config("USE_CUSTOMIZED_GRAPHRAG_SETTING", default="false").lower() == "true":
+            settings_yaml_path = Path(root_path) / "settings.yaml"
+            with open(settings_yaml_path, "r") as f:
+                settings = yaml.safe_load(f)
+            if settings["embeddings"]["llm"]["model"]:
+                embedding_model = settings["embeddings"]["llm"]["model"]
+            if settings["embeddings"]["llm"]["api_key"]:
+                embedding_api_key = settings["embeddings"]["llm"]["api_key"]
+            if settings["embeddings"]["llm"]["api_base"]:
+                embedding_api_base = settings["embeddings"]["llm"]["api_base"]
+
         text_embedder = OpenAIEmbedding(
-            api_key=os.getenv("GRAPHRAG_API_KEY"),
-            api_base=None,
+            api_key=embedding_api_key,
+            api_base=embedding_api_base,
             api_type=OpenaiApiType.OpenAI,
             model=embedding_model,
             deployment_name=embedding_model,
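Note that the override block above indexes settings["embeddings"]["llm"][...] directly, so a customized settings.yaml that omits any of those keys raises a KeyError before the environment-derived defaults can apply. A more defensive variant (a sketch operating on the same variables, not part of this PR) could read the nested keys with dict.get:

# Sketch: tolerate missing sections/keys in a user-edited settings.yaml;
# the environment-derived defaults survive when a key is absent.
llm_settings = (settings.get("embeddings") or {}).get("llm") or {}
embedding_model = llm_settings.get("model") or embedding_model
embedding_api_key = llm_settings.get("api_key") or embedding_api_key
embedding_api_base = llm_settings.get("api_base") or embedding_api_base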
159 changes: 159 additions & 0 deletions settings.yaml.example
@@ -0,0 +1,159 @@
# This is a sample GraphRAG settings.yaml file that allows users to run the GraphRAG index process with their own customized parameters.
# The parameters in this file only take effect when USE_CUSTOMIZED_GRAPHRAG_SETTING is set to true in the .env file.
# For a comprehensive understanding of GraphRAG parameters, please refer to: https://microsoft.github.io/graphrag/config/json_yaml/.
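# Note: values written as ${VAR}, e.g. ${GRAPHRAG_API_KEY} below, are environment-variable
# references that GraphRAG resolves at runtime (from the environment or a .env file).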

encoding_model: cl100k_base
skip_workflows: []
llm:
  api_key: ${GRAPHRAG_API_KEY}
  type: openai_chat # or azure_openai_chat
  api_base: http://127.0.0.1:11434/v1
  model: qwen2
  model_supports_json: true # recommended if this is available for your model.
  # max_tokens: 4000
  request_timeout: 1800.0
  # api_base: https://<instance>.openai.azure.com
  # api_version: 2024-02-15-preview
  # organization: <organization_id>
  # deployment_name: <azure_model_deployment_name>
  # tokens_per_minute: 150_000 # set a leaky bucket throttle
  # requests_per_minute: 10_000 # set a leaky bucket throttle
  # max_retries: 10
  # max_retry_wait: 10.0
  # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
  concurrent_requests: 5 # the number of parallel inflight requests that may be made
  # temperature: 0 # temperature for sampling
  # top_p: 1 # top-p sampling
  # n: 1 # Number of completions to generate

parallelization:
  stagger: 0.3
  # num_threads: 50 # the number of threads to use for parallel processing

async_mode: threaded # or asyncio

embeddings:
  ## parallelization: override the global parallelization settings for embeddings
  async_mode: threaded # or asyncio
  # target: required # or all
  # batch_size: 16 # the number of documents to send in a single request
  # batch_max_tokens: 8191 # the maximum number of tokens to send in a single request
  llm:
    api_base: http://localhost:11434/v1
    api_key: ${GRAPHRAG_API_KEY}
    model: nomic-embed-text
    type: openai_embedding
    # api_base: https://<instance>.openai.azure.com
    # api_version: 2024-02-15-preview
    # organization: <organization_id>
    # deployment_name: <azure_model_deployment_name>
    # tokens_per_minute: 150_000 # set a leaky bucket throttle
    # requests_per_minute: 10_000 # set a leaky bucket throttle
    # max_retries: 10
    # max_retry_wait: 10.0
    # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
    # concurrent_requests: 25 # the number of parallel inflight requests that may be made

chunks:
  size: 1200
  overlap: 100
  group_by_columns: [id] # by default, we don't allow chunks to cross documents

input:
  type: file # or blob
  file_type: text # or csv
  base_dir: "input"
  file_encoding: utf-8
  file_pattern: ".*\\.txt$"

cache:
  type: file # or blob
  base_dir: "cache"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

storage:
  type: file # or blob
  base_dir: "output"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

reporting:
  type: file # or console, blob
  base_dir: "output"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

entity_extraction:
  ## strategy: fully override the entity extraction strategy.
  ##   type: one of graph_intelligence, graph_intelligence_json and nltk
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/entity_extraction.txt"
  entity_types: [organization,person,geo,event]
  max_gleanings: 1

summarize_descriptions:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/summarize_descriptions.txt"
  max_length: 500

claim_extraction:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  # enabled: true
  prompt: "prompts/claim_extraction.txt"
  description: "Any claims or facts that could be relevant to information discovery."
  max_gleanings: 1

community_reports:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/community_report.txt"
  max_length: 2000
  max_input_length: 8000

cluster_graph:
  max_cluster_size: 10

embed_graph:
  enabled: false # if true, will generate node2vec embeddings for nodes
  # num_walks: 10
  # walk_length: 40
  # window_size: 2
  # iterations: 3
  # random_seed: 597832

umap:
  enabled: false # if true, will generate UMAP embeddings for nodes

snapshots:
  graphml: false
  raw_entities: false
  top_level_nodes: false

local_search:
  # text_unit_prop: 0.5
  # community_prop: 0.1
  # conversation_history_max_turns: 5
  # top_k_mapped_entities: 10
  # top_k_relationships: 10
  # llm_temperature: 0 # temperature for sampling
  # llm_top_p: 1 # top-p sampling
  # llm_n: 1 # Number of completions to generate
  # max_tokens: 12000

global_search:
  # llm_temperature: 0 # temperature for sampling
  # llm_top_p: 1 # top-p sampling
  # llm_n: 1 # Number of completions to generate
  # max_tokens: 12000
  # data_max_tokens: 12000
  # map_max_tokens: 1000
  # reduce_max_tokens: 2000
  # concurrency: 32
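To try the feature end to end (a usage sketch inferred from the code above, not steps stated in the PR): edit settings.yaml.example in the working directory the app runs from, since call_graphrag_index copies that file into the index's input path as settings.yaml, then enable the flag in your .env:

# .env
USE_CUSTOMIZED_GRAPHRAG_SETTING=true
GRAPHRAG_API_KEY=<YOUR_OPENAI_KEY>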