Skip to content

Commit 5c6ddfb

Browse files
authored
Add Support for Knowledge Base Retrieval Params & Preprocessing (#61)
1 parent d9bef29 commit 5c6ddfb

File tree

3 files changed

+102
-5
lines changed

3 files changed

+102
-5
lines changed

minds/knowledge_bases/knowledge_bases.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pydantic import BaseModel
44

5+
from minds.knowledge_bases.preprocessing import PreprocessingConfig
56
from minds.rest_api import RestAPI
67

78

@@ -25,6 +26,8 @@ class KnowledgeBaseConfig(BaseModel):
2526
description: str
2627
vector_store_config: Optional[VectorStoreConfig] = None
2728
embedding_config: Optional[EmbeddingConfig] = None
29+
# Params to apply to retrieval pipeline.
30+
params: Optional[Dict] = None
2831

2932

3033
class KnowledgeBaseDocument(BaseModel):
@@ -39,7 +42,7 @@ def __init__(self, name, api: RestAPI):
3942
self.name = name
4043
self.api = api
4144

42-
def insert_from_select(self, query: str):
45+
def insert_from_select(self, query: str, preprocessing_config: PreprocessingConfig = None):
4346
'''
4447
Inserts select content of a connected datasource into this knowledge base
4548
@@ -48,9 +51,11 @@ def insert_from_select(self, query: str):
4851
update_request = {
4952
'query': query
5053
}
54+
if preprocessing_config is not None:
55+
update_request['preprocessing'] = preprocessing_config.model_dump()
5156
_ = self.api.put(f'/knowledge_bases/{self.name}', data=update_request)
5257

53-
def insert_documents(self, documents: List[KnowledgeBaseDocument]):
58+
def insert_documents(self, documents: List[KnowledgeBaseDocument], preprocessing_config: PreprocessingConfig = None):
5459
'''
5560
Inserts documents directly into this knowledge base
5661
@@ -59,9 +64,11 @@ def insert_documents(self, documents: List[KnowledgeBaseDocument]):
5964
update_request = {
6065
'rows': [d.model_dump() for d in documents]
6166
}
67+
if preprocessing_config is not None:
68+
update_request['preprocessing'] = preprocessing_config.model_dump()
6269
_ = self.api.put(f'/knowledge_bases/{self.name}', data=update_request)
6370

64-
def insert_urls(self, urls: List[str]):
71+
def insert_urls(self, urls: List[str], preprocessing_config: PreprocessingConfig = None):
6572
'''
6673
Crawls URLs & inserts the retrieved webpages into this knowledge base
6774
@@ -70,9 +77,11 @@ def insert_urls(self, urls: List[str]):
7077
update_request = {
7178
'urls': urls
7279
}
80+
if preprocessing_config is not None:
81+
update_request['preprocessing'] = preprocessing_config.model_dump()
7382
_ = self.api.put(f'/knowledge_bases/{self.name}', data=update_request)
7483

75-
def insert_files(self, files: List[str]):
84+
def insert_files(self, files: List[str], preprocessing_config: PreprocessingConfig = None):
7685
'''
7786
Inserts files that have already been uploaded to MindsDB into this knowledge base
7887
@@ -81,6 +90,8 @@ def insert_files(self, files: List[str]):
8190
update_request = {
8291
'files': files
8392
}
93+
if preprocessing_config is not None:
94+
update_request['preprocessing'] = preprocessing_config.model_dump()
8495
_ = self.api.put(f'/knowledge_bases/{self.name}', data=update_request)
8596

8697

@@ -117,6 +128,8 @@ def create(self, config: KnowledgeBaseConfig) -> KnowledgeBase:
117128
if config.embedding_config.params is not None:
118129
embedding_data.update(config.embedding_config.params)
119130
create_request['embedding_model'] = embedding_data
131+
if config.params is not None:
132+
create_request['params'] = config.params
120133

121134
_ = self.api.post('/knowledge_bases', data=create_request)
122135
return self.get(config.name)
+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from typing import Any, Dict, List, Literal, Optional
2+
3+
from pydantic import BaseModel, Field, model_validator
4+
5+
6+
DEFAULT_LLM_MODEL = 'gpt-4o'
7+
DEFAULT_LLM_MODEL_PROVIDER = 'openai'
8+
9+
10+
class TextChunkingConfig(BaseModel):
11+
'''Configuration for chunking text content before they are inserted into a knowledge base'''
12+
separators: List[str] = Field(
13+
default=['\n\n', '\n', ' ', ''],
14+
description='List of separators to use for splitting text, in order of priority'
15+
)
16+
chunk_size: int = Field(
17+
default=1000,
18+
description='The target size of each text chunk',
19+
gt=0
20+
)
21+
chunk_overlap: int = Field(
22+
default=200,
23+
description='The number of characters to overlap between chunks',
24+
ge=0
25+
)
26+
27+
28+
class LLMConfig(BaseModel):
29+
model_name: str = Field(default=DEFAULT_LLM_MODEL, description='LLM model to use for context generation')
30+
provider: str = Field(default=DEFAULT_LLM_MODEL_PROVIDER, description='LLM model provider to use for context generation')
31+
params: Dict[str, Any] = Field(default={}, description='Additional parameters to pass in when initializing the LLM')
32+
33+
34+
class ContextualConfig(BaseModel):
35+
'''Configuration specific to contextual preprocessing'''
36+
llm_config: LLMConfig = Field(
37+
default=LLMConfig(),
38+
description='LLM configuration to use for context generation'
39+
)
40+
context_template: Optional[str] = Field(
41+
default=None,
42+
description='Custom template for context generation'
43+
)
44+
chunk_size: int = Field(
45+
default=1000,
46+
description='The target size of each text chunk',
47+
gt=0
48+
)
49+
chunk_overlap: int = Field(
50+
default=200,
51+
description='The number of characters to overlap between chunks',
52+
ge=0
53+
)
54+
55+
56+
class PreprocessingConfig(BaseModel):
57+
'''Complete preprocessing configuration'''
58+
type: Literal['contextual', 'text_chunking'] = Field(
59+
default='text_chunking',
60+
description='Type of preprocessing to apply'
61+
)
62+
contextual_config: Optional[ContextualConfig] = Field(
63+
default=None,
64+
description='Configuration for contextual preprocessing'
65+
)
66+
text_chunking_config: Optional[TextChunkingConfig] = Field(
67+
default=None,
68+
description='Configuration for text chunking preprocessing'
69+
)
70+
71+
@model_validator(mode='after')
72+
def validate_config_presence(self) -> 'PreprocessingConfig':
73+
'''Ensure the appropriate config is present for the chosen type'''
74+
if self.type == 'contextual' and not self.contextual_config:
75+
self.contextual_config = ContextualConfig()
76+
if self.type == 'text_chunking' and not self.text_chunking_config:
77+
self.text_chunking_config = TextChunkingConfig()
78+
return self

tests/unit/test_unit.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,10 @@ def test_create_knowledge_bases(self, mock_post, mock_get):
130130
name='test_kb',
131131
description='Test knowledge base',
132132
vector_store_config=test_vector_store_config,
133-
embedding_config=test_embedding_config
133+
embedding_config=test_embedding_config,
134+
params={
135+
'k1': 'v1'
136+
}
134137
)
135138
response_mock(mock_get, test_knowledge_base_config.model_dump())
136139

@@ -152,6 +155,9 @@ def test_create_knowledge_bases(self, mock_post, mock_get):
152155
'provider': test_embedding_config.provider,
153156
'name': test_embedding_config.model,
154157
'k1': 'v1'
158+
},
159+
'params': {
160+
'k1': 'v1'
155161
}
156162
}
157163

0 commit comments

Comments
 (0)