
Commit 6c1900f

feat: added --filter param to search command (#10)
1 parent 2265bbe commit 6c1900f

3 files changed, +263 -30 lines changed

gptme_rag/cli.py

Lines changed: 23 additions & 2 deletions
@@ -358,6 +358,12 @@ def index(
     type=click.Choice(["cuda", "cpu"]),
     help="Device to run embeddings on (cuda or cpu)",
 )
+@click.option(
+    "--filter",
+    "-f",
+    multiple=True,
+    help="Filter results by path pattern (glob). Can be specified multiple times.",
+)
 def search(
     query: str,
     paths: list[Path],
@@ -371,6 +377,7 @@ def search(
     weights: str | None,
     embedding_function: str | None,
     device: str | None,
+    filter: tuple[str, ...],
 ):
     """Search the index and assemble context."""
     paths = [path.resolve() for path in paths]
@@ -405,13 +412,27 @@ def search(
             device=device or "cpu",
         )
         assembler = ContextAssembler(max_tokens=max_tokens)
+
+        # Combine paths and filters for search
+        search_paths = list(paths)
+        if filter:
+            # If no paths were specified but filters are present,
+            # search from root and apply filters
+            if not paths:
+                search_paths = [Path(".")]
+            logger.debug(f"Using path filters: {filter}")
+
         if explain:
             documents, distances, explanations = indexer.search(
-                query, n_results=n_results, paths=paths, explain=True
+                query,
+                n_results=n_results,
+                paths=search_paths,
+                path_filters=filter,
+                explain=True,
             )
         else:
             documents, distances, _ = indexer.search(
-                query, n_results=n_results, paths=paths
+                query, n_results=n_results, paths=search_paths, path_filters=filter
             )
     finally:
         sys.stdout.close()
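
For a quick end-to-end check of the new flag, click's test runner can invoke the command in-process. This is a minimal sketch, not part of the commit: it assumes the decorated search command is importable from gptme_rag.cli, that an index already exists, and that the query is accepted as a positional argument; the query string and patterns are placeholders.

    from click.testing import CliRunner

    from gptme_rag.cli import search

    runner = CliRunner()
    # --filter/-f is repeatable; here an extension filter is combined with a path pattern.
    result = runner.invoke(
        search,
        ["error handling", "--filter", "*.md", "--filter", "src/**/*.py"],
    )
    print(result.exit_code)
    print(result.output)

Since no paths are given here, the command falls back to searching from the root (Path(".")) and applies the filters there, as in the diff above.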

gptme_rag/indexing/indexer.py

Lines changed: 145 additions & 11 deletions
@@ -483,16 +483,85 @@ def compute_relevance_score(
 
         return total_score, scores
 
-    def _matches_paths(self, doc: Document, paths: list[Path]) -> bool:
-        """Check if document matches any of the given paths."""
+    def _matches_paths(
+        self,
+        doc: Document,
+        paths: list[Path] | None = None,
+        path_filters: tuple[str, ...] | None = None,
+    ) -> bool:
+        """Check if document matches any of the given paths or filters.
+
+        Args:
+            doc: Document to check
+            paths: List of paths to match against (exact path matching)
+            path_filters: List of glob patterns to match against
+
+        Returns:
+            bool: True if document matches any path or filter
+        """
         source = doc.metadata.get("source", "")
         if not source:
             return False
+
         source_path = Path(source)
-        return any(
-            path.resolve() in source_path.parents or path.resolve() == source_path
-            for path in paths
-        )
+
+        path_match = True
+        filter_match = True
+
+        # Check exact path matches if paths are specified
+        if paths:
+            path_match = any(
+                path.resolve() in source_path.parents or path.resolve() == source_path
+                for path in paths
+            )
+            if not path_match:
+                logger.debug(f"Path match failed: {source_path} not in {paths}")
+                return False
+
+        # Check pattern matches if filters are specified
+        if path_filters:
+            # Get both the full path and relative components for matching
+            source_str = str(source_path)
+            source_name = source_path.name
+            source_parts = source_path.parts
+
+            filter_match = False  # Set to True if any pattern matches
+            for pattern in path_filters:
+                logger.debug(f"Checking pattern: {pattern} against {source_str}")
+
+                # Handle different pattern types
+                if pattern.startswith("*."):
+                    # Simple extension filter
+                    if source_name.endswith(pattern[1:]):
+                        logger.debug(f"Matched extension pattern: {pattern}")
+                        filter_match = True
+                        break
+                else:
+                    # Convert pattern to parts for matching
+                    pattern_path = Path(pattern)
+                    pattern_parts = pattern_path.parts
+
+                    # Try different matching strategies
+                    if (
+                        fnmatch_path(source_str, pattern)
+                        or fnmatch_path(source_str, f"**/{pattern}")
+                        or (
+                            len(pattern_parts) <= len(source_parts)
+                            and fnmatch_path(
+                                str(Path(*source_parts[-len(pattern_parts) :])), pattern
+                            )
+                        )
+                    ):
+                        logger.debug(f"Matched path pattern: {pattern}")
+                        filter_match = True
+                        break
+
+            if not filter_match:
+                logger.debug(f"No patterns matched: {source_str}")
+                return False
+
+        # Both conditions must be met (if specified)
+        return path_match and filter_match
 
     def search(
         self,
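
The non-extension branch above tries three strategies with fnmatch_path: the full path, the pattern anchored anywhere via a **/ prefix, and the pattern against the same number of trailing path components. A minimal standalone sketch of that logic, assuming fnmatch_path behaves like the standard library's fnmatch.fnmatch; the helper name matches_filter is illustrative, not the project's API:

    import fnmatch
    from pathlib import Path


    def matches_filter(source: str, pattern: str) -> bool:
        """Sketch of the pattern checks in _matches_paths, not the project code."""
        source_path = Path(source)
        # Simple extension filters like "*.md" only look at the file name.
        if pattern.startswith("*."):
            return source_path.name.endswith(pattern[1:])
        parts = source_path.parts
        pattern_parts = Path(pattern).parts
        return (
            fnmatch.fnmatch(source, pattern)  # full-path match
            or fnmatch.fnmatch(source, f"**/{pattern}")  # match anywhere in the path
            or (
                # match against the same number of trailing path components
                len(pattern_parts) <= len(parts)
                and fnmatch.fnmatch(str(Path(*parts[-len(pattern_parts):])), pattern)
            )
        )


    # matches_filter("/home/user/project/docs/guide.md", "docs/*.md") -> True
    # matches_filter("/home/user/project/docs/guide.md", "src/*.md") -> False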
@@ -503,16 +572,80 @@ def search(
         group_chunks: bool = True,
         max_attempts: int = 3,
         explain: bool = False,
+        path_filters: tuple[str, ...] | None = None,
     ) -> tuple[list[Document], list[float], list[dict[str, Any]] | None]:
-        """Search for documents similar to the query."""
+        """Search for documents similar to the query.
+
+        Args:
+            query: The search query text
+            paths: List of paths to search within (exact path matching)
+            n_results: Maximum number of results to return
+            where: Additional where clauses for ChromaDB query
+            group_chunks: Whether to group chunks from the same document
+            max_attempts: Maximum number of search attempts
+            explain: Whether to return scoring explanations
+            path_filters: Glob patterns to filter documents by path. Supports:
+                - Simple extension filters (*.md, *.py)
+                - Path patterns (src/*.py, docs/**/*.md)
+                - Multiple patterns can be combined
+
+        Returns:
+            Tuple of (documents, distances, explanations)
+            - documents: List of matching Document objects
+            - distances: List of embedding distances
+            - explanations: List of scoring explanations (if explain=True)
+
+        Examples:
+            # Search in markdown files
+            search("query", path_filters=("*.md",))
+
+            # Search in Python files in src directory
+            search("query", path_filters=("src/**/*.py",))
+
+            # Search in multiple file types
+            search("query", path_filters=("*.md", "*.py"))
+
+            # Combine paths and filters
+            search("query", paths=[Path("docs")], path_filters=("*.md",))
+        """
         # Get more results than needed to allow for filtering
         query_n_results = n_results * 3 if group_chunks else n_results
 
+        # Prepare where clause
+        search_where = where.copy() if where else {}
+
+        # Pre-filter documents based on all patterns
+        if path_filters:
+            logger.debug(f"Filtering with patterns: {path_filters}")
+            all_docs = self.collection.get()
+            matching_sources = set()
+
+            for meta in all_docs["metadatas"]:
+                if not meta or "source" not in meta:
+                    continue
+
+                source_path = Path(meta["source"])
+                # Create a dummy document for path matching
+                doc = Document(
+                    content="", metadata=meta, doc_id="temp", source_path=source_path
+                )
+
+                # Use _matches_paths to check all patterns
+                if self._matches_paths(doc, paths=None, path_filters=path_filters):
+                    matching_sources.add(str(source_path))
+
+            if matching_sources:
+                logger.debug(f"Found {len(matching_sources)} matching files")
+                search_where["source"] = {"$in": list(matching_sources)}
+            else:
+                logger.debug("No files matched the filter patterns")
+                return [], [], [] if explain else None
+
         # Query the collection
         results = self.collection.query(
             query_texts=[query],
             n_results=query_n_results,
-            where=where,
+            where=search_where,
         )
 
         if not results["ids"][0]:
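
The pre-filtering step narrows the vector query itself rather than discarding hits afterwards: it scans the stored metadatas, keeps the sources whose paths match any pattern, and passes them to ChromaDB as a "$in" where clause. A condensed sketch of the same flow, reusing the matches_filter helper sketched above; the collection object and arguments are assumed, not the project's actual objects:

    from typing import Any


    def prefiltered_query(
        collection: Any, query: str, n_results: int, path_filters: tuple[str, ...]
    ):
        """Sketch: restrict a ChromaDB query to sources matching the glob filters."""
        all_docs = collection.get()  # metadatas for every stored chunk
        matching = {
            meta["source"]
            for meta in all_docs["metadatas"]
            if meta and "source" in meta
            and any(matches_filter(meta["source"], p) for p in path_filters)
        }
        if not matching:
            # No stored source can match, so skip the embedding query entirely.
            return None
        return collection.query(
            query_texts=[query],
            n_results=n_results,
            where={"source": {"$in": list(matching)}},
        )

Note the empty-match case: the indexer above returns an empty result set immediately instead of running an unfiltered query.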
@@ -530,7 +663,7 @@ def search(
                     metadata=results["metadatas"][0][i],
                     doc_id=doc_id,
                 )
-                if not paths or self._matches_paths(doc, paths):
+                if self._matches_paths(doc, paths, path_filters):
                     docs_by_source[source_id] = (doc, results["distances"][0][i])
 
             # Take top n results
@@ -541,7 +674,7 @@ def search(
         else:
             # Process individual chunks
             documents, distances, _ = self._process_individual_chunks(
-                results, paths, n_results, explain
+                results, paths, n_results, explain, path_filters
             )
 
         # Add explanations if requested
@@ -564,6 +697,7 @@ def _process_individual_chunks(
         paths: list[Path] | None,
         n_results: int,
         explain: bool,
+        path_filters: tuple[str, ...] | None = None,
     ) -> tuple[list[Document], list[float], list[dict]]:
         """Process search results as individual chunks."""
         documents: list[Document] = []
@@ -583,7 +717,7 @@ def _process_individual_chunks(
                 doc_id=doc_id,
             )
 
-            if paths and not self._matches_paths(doc, paths):
+            if not self._matches_paths(doc, paths, path_filters):
                 continue
 
             documents.append(doc)
tests/test_indexing.py

Lines changed: 95 additions & 17 deletions
@@ -1,3 +1,5 @@
+from pathlib import Path
+
 import pytest
 from gptme_rag.indexing.document import Document
 
@@ -72,24 +74,100 @@ def test_indexer_add_documents(indexer, test_docs):
 
 
 def test_indexer_directory(indexer, tmp_path):
-    # Create test files
-    (tmp_path / "test1.txt").write_text("Content about Python")
-    (tmp_path / "test2.txt").write_text("Content about JavaScript")
-    (tmp_path / "subdir").mkdir()
-    (tmp_path / "subdir" / "test3.txt").write_text("Content about TypeScript")
+    # Create test files in different directories with different extensions
+    docs_dir = tmp_path / "docs"
+    src_dir = tmp_path / "src"
+    docs_dir.mkdir()
+    src_dir.mkdir()
 
-    indexer.index_directory(tmp_path)
+    # Create markdown files in docs
+    (docs_dir / "guide.md").write_text("Python programming guide")
+    (docs_dir / "tutorial.md").write_text("JavaScript tutorial")
 
-    # Search for programming languages
-    python_results, python_distances, _ = indexer.search("Python")
-    js_results, js_distances, _ = indexer.search("JavaScript")
-    ts_results, ts_distances, _ = indexer.search("TypeScript")
+    # Create Python files in src
+    (src_dir / "main.py").write_text("def main(): print('Hello')")
+    (src_dir / "utils.py").write_text("def util(): return True")
 
-    assert len(python_results) > 0
-    assert len(js_results) > 0
-    assert len(ts_results) > 0
+    # Create a text file in root
+    (tmp_path / "notes.txt").write_text("Random notes")
+
+    # Index everything
+    indexer.index_directory(tmp_path)
 
-    # Verify distances are returned
-    assert len(python_distances) > 0
-    assert len(js_distances) > 0
-    assert len(ts_distances) > 0
+    # Test extension filter (*.md)
+    md_results, _, _ = indexer.search(
+        "programming",
+        path_filters=("*.md",),
+    )
+    assert len(md_results) > 0
+    assert all(doc.metadata["source"].endswith(".md") for doc in md_results)
+
+    # Test directory pattern (src/*.py)
+    py_results, _, _ = indexer.search(
+        "def",
+        path_filters=(str(src_dir / "*.py"),),
+    )
+    assert len(py_results) > 0
+    assert all(
+        Path(doc.metadata["source"]).parent.name == "src"
+        and doc.metadata["source"].endswith(".py")
+        for doc in py_results
+    )
+
+    # Test multiple patterns
+    multi_results, _, _ = indexer.search(
+        "programming",
+        path_filters=("*.md", "*.py"),
+    )
+    assert len(multi_results) > 0
+    assert all(doc.metadata["source"].endswith((".md", ".py")) for doc in multi_results)
+
+    # Test with path and filter combined
+    docs_md_results, _, _ = indexer.search(
+        "tutorial",
+        paths=[docs_dir],
+        path_filters=("*.md",),
+    )
+    assert len(docs_md_results) > 0
+    assert all(
+        Path(doc.metadata["source"]).parent.name == "docs"
+        and doc.metadata["source"].endswith(".md")
+        for doc in docs_md_results
+    )
+
+
+def test_path_matching(indexer):
+    # Test the _matches_paths method directly
+    doc = Document(
+        content="test",
+        metadata={"source": "/home/user/project/docs/guide.md"},
+        doc_id="test",
+    )
+
+    # Test simple extension filter
+    assert indexer._matches_paths(doc, path_filters=("*.md",))
+    assert not indexer._matches_paths(doc, path_filters=("*.py",))
+
+    # Test directory pattern
+    assert indexer._matches_paths(doc, path_filters=("docs/*.md",))
+    assert not indexer._matches_paths(doc, path_filters=("src/*.md",))
+
+    # Test multiple patterns
+    assert indexer._matches_paths(doc, path_filters=("*.py", "*.md"))
+    assert indexer._matches_paths(doc, path_filters=("src/*.py", "docs/*.md"))
+
+    # Test with exact paths
+    assert indexer._matches_paths(doc, paths=[Path("/home/user/project/docs")])
+    assert not indexer._matches_paths(doc, paths=[Path("/home/user/project/src")])
+
+    # Test combining paths and filters
+    assert indexer._matches_paths(
+        doc,
+        paths=[Path("/home/user/project/docs")],
+        path_filters=("*.md",),
+    )
+    assert not indexer._matches_paths(
+        doc,
        paths=[Path("/home/user/project/docs")],
        path_filters=("*.py",),
    )
