Skip to content

Commit 03e3732

Browse files
authored
Make code_embedding example deal with various languages. (#145)
#109
1 parent f28b7df commit 03e3732

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

examples/code_embedding/main.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
11
from dotenv import load_dotenv
22

33
import cocoindex
4+
import os
5+
6+
class ExtractExtension(cocoindex.op.FunctionSpec):
7+
"""Summarize a Python module."""
8+
9+
@cocoindex.op.executor_class()
10+
class ExtractExtensionExecutor:
11+
"""Executor for ExtractExtension."""
12+
13+
spec: ExtractExtension
14+
15+
def __call__(self, filename: str) -> str:
16+
return os.path.splitext(filename)[1]
417

518
def code_to_embedding(text: cocoindex.DataSlice) -> cocoindex.DataSlice:
619
"""
@@ -17,14 +30,15 @@ def code_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
1730
"""
1831
data_scope["files"] = flow_builder.add_source(
1932
cocoindex.sources.LocalFile(path="../..",
20-
included_patterns=["*.py"],
21-
excluded_patterns=[".*"]))
33+
included_patterns=["*.py", "*.rs", "*.toml", "*.md", "*.mdx"],
34+
excluded_patterns=[".*", "target", "**/node_modules"]))
2235
code_embeddings = data_scope.add_collector()
2336

2437
with data_scope["files"].row() as file:
38+
file["extension"] = file["filename"].transform(ExtractExtension())
2539
file["chunks"] = file["content"].transform(
2640
cocoindex.functions.SplitRecursively(),
27-
language="python", chunk_size=1000, chunk_overlap=300)
41+
language=file["extension"], chunk_size=1000, chunk_overlap=300)
2842
with file["chunks"].row() as chunk:
2943
chunk["embedding"] = chunk["text"].call(code_to_embedding)
3044
code_embeddings.collect(filename=file["filename"], location=chunk["location"],

src/ops/functions/split_recursively.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ static TREE_SITTER_LANGUAGE_BY_LANG: LazyLock<HashMap<UniCase<&'static str>, Arc
117117
add_language(
118118
&mut map,
119119
"Markdown",
120-
[".md", "md"],
120+
[".md", ".mdx", "md"],
121121
tree_sitter_md::LANGUAGE,
122122
["inline"],
123123
);

0 commit comments

Comments
 (0)