Skip to content

Commit 885bc68

Browse files
authored
Kz/287 feat nested fields as columns (#308)
1 parent 9c24e7f commit 885bc68

File tree

4 files changed

+215
-21
lines changed

4 files changed

+215
-21
lines changed

packages/ragbits-cli/src/ragbits/cli/state.py

+53-18
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from collections.abc import Mapping, Sequence
33
from dataclasses import dataclass
44
from enum import Enum
5-
from typing import TypeVar
5+
from types import UnionType
6+
from typing import Optional, TypeVar, Union, get_args, get_origin
67

78
import typer
89
from pydantic import BaseModel
10+
from pydantic.fields import FieldInfo
911
from rich.console import Console
1012
from rich.table import Column, Table
1113

@@ -48,39 +50,72 @@ def print_output_table(
4850
console.print("No results")
4951
return
5052

51-
fields = {**data[0].model_fields, **data[0].model_computed_fields}
52-
53-
# Human readable titles for columns
54-
titles = {
55-
key: value.get("title", key)
56-
for key, value in data[0].model_json_schema(mode="serialization")["properties"].items()
57-
}
53+
base_fields = {**data[0].model_fields, **data[0].model_computed_fields}
5854

5955
# Normalize the list of columns
6056
if columns is None:
61-
columns = {key: Column() for key in fields}
57+
columns = {key: Column() for key in base_fields}
6258
elif isinstance(columns, str):
6359
columns = {key: Column() for key in columns.split(",")}
6460
elif isinstance(columns, Sequence):
6561
columns = {key: Column() for key in columns}
6662

67-
# Add headers to columns if not provided
68-
for key in columns:
69-
if key not in fields:
70-
Console(stderr=True).print(f"Unknown column: {key}")
71-
raise typer.Exit(1)
72-
73-
column = columns[key]
63+
# check if columns are correct
64+
for column_name in columns:
65+
field = _get_nested_field(column_name, base_fields)
66+
column = columns[column_name]
7467
if column.header == "":
75-
column.header = titles.get(key, key)
68+
column.header = field.title if field.title else column_name.replace("_", " ").replace(".", " ").title()
7669

7770
# Create and print the table
7871
table = Table(*columns.values(), show_header=True, header_style="bold magenta")
72+
7973
for row in data:
80-
table.add_row(*[str(getattr(row, key)) for key in columns])
74+
row_to_add = []
75+
for key in columns:
76+
*path_fragments, field_name = key.strip().split(".")
77+
base_row = row
78+
for fragment in path_fragments:
79+
base_row = getattr(base_row, fragment)
80+
z = getattr(base_row, field_name)
81+
row_to_add.append(str(z))
82+
table.add_row(*row_to_add)
83+
8184
console.print(table)
8285

8386

87+
def _get_nested_field(column_name: str, base_fields: dict) -> FieldInfo:
88+
"""
89+
Check if column name exists in the model schema.
90+
91+
Args:
92+
column_name: name of the column to check
93+
base_fields: model fields
94+
Returns:
95+
field: nested field
96+
"""
97+
fields = base_fields
98+
*path_fragments, field_name = column_name.strip().split(".")
99+
for fragment in path_fragments:
100+
if fragment not in fields:
101+
Console(stderr=True).print(
102+
f"Unknown column: {'.'.join(path_fragments + [field_name])} ({fragment} not found)"
103+
)
104+
raise typer.Exit(1)
105+
model_class = fields[fragment].annotation
106+
if get_origin(model_class) in [UnionType, Optional, Union]:
107+
types = get_args(model_class)
108+
model_class = next((t for t in types if t is not type(None)), None)
109+
if model_class and issubclass(model_class, BaseModel):
110+
fields = {**model_class.model_fields, **model_class.model_computed_fields}
111+
if field_name not in fields:
112+
Console(stderr=True).print(
113+
f"Unknown column: {'.'.join(path_fragments + [field_name])} ({field_name} not found)"
114+
)
115+
raise typer.Exit(1)
116+
return fields[field_name]
117+
118+
84119
def print_output_json(data: Sequence[ModelT]) -> None:
85120
"""
86121
Display data from Pydantic models in a JSON format.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from pathlib import Path
2+
from unittest.mock import MagicMock, patch
3+
4+
import pytest
5+
import typer
6+
from pydantic import BaseModel
7+
from pydantic.fields import Field, FieldInfo
8+
from rich.table import Column, Table
9+
10+
from ragbits.cli.state import OutputType, _get_nested_field, print_output, print_output_table
11+
from ragbits.document_search.documents.sources import LocalFileSource
12+
13+
14+
class InnerTestModel(BaseModel):
15+
id: int
16+
name: str = Field(title="Name of the inner model", description="Name of the inner model")
17+
location: LocalFileSource
18+
19+
20+
class OtherTestModel(BaseModel):
21+
id: int
22+
name: str
23+
location: InnerTestModel
24+
25+
26+
class MainTestModel(BaseModel):
27+
id: int
28+
name: str
29+
model: OtherTestModel | None
30+
31+
32+
data = [
33+
MainTestModel(
34+
id=1,
35+
name="A",
36+
model=OtherTestModel(
37+
id=11,
38+
name="aa",
39+
location=InnerTestModel(id=111, name="aa1", location=LocalFileSource(path=Path("folder_1"))),
40+
),
41+
),
42+
MainTestModel(
43+
id=2,
44+
name="B",
45+
model=OtherTestModel(
46+
id=22,
47+
name="bb",
48+
location=InnerTestModel(id=222, name="aa2", location=LocalFileSource(path=Path("folder_2"))),
49+
),
50+
),
51+
]
52+
53+
54+
@patch("ragbits.cli.state.print_output_table")
55+
@patch("ragbits.cli.state.print_output_json")
56+
def test_print_output_text(mock_print_output_json: MagicMock, mock_print_output_table: MagicMock):
57+
with patch("ragbits.cli.state.cli_state") as mock_cli_state:
58+
mock_cli_state.output_type = OutputType.text
59+
columns = {"id": Column(), "name": Column()}
60+
print_output(data, columns=columns)
61+
mock_print_output_table.assert_called_once_with(data, columns)
62+
mock_print_output_json.assert_not_called()
63+
64+
65+
@patch("ragbits.cli.state.print_output_table")
66+
@patch("ragbits.cli.state.print_output_json")
67+
def test_print_output_json(mock_print_output_json: MagicMock, mock_print_output_table: MagicMock):
68+
with patch("ragbits.cli.state.cli_state") as mock_cli_state:
69+
mock_cli_state.output_type = OutputType.json
70+
print_output(data)
71+
mock_print_output_table.assert_not_called()
72+
mock_print_output_json.assert_called_once_with(data)
73+
74+
75+
def test_print_output_unsupported_output_type():
76+
with patch("ragbits.cli.state.cli_state") as mock_cli_state:
77+
mock_cli_state.output_type = "unsupported_type"
78+
with pytest.raises(ValueError, match="Unsupported output type: unsupported_type"):
79+
print_output(data)
80+
81+
82+
def test_print_output_table():
83+
with patch("rich.console.Console.print") as mock_print:
84+
columns = {"id": Column(), "model.location.location.path": Column(), "model.location.name": Column()}
85+
print_output_table(data, columns)
86+
mock_print.assert_called_once()
87+
args, _ = mock_print.call_args_list[0]
88+
printed_table = args[0]
89+
assert isinstance(printed_table, Table)
90+
assert printed_table.columns[0].header == "Id"
91+
assert printed_table.columns[1].header == "Model Location Location Path"
92+
assert printed_table.columns[2].header == "Name of the inner model"
93+
assert printed_table.row_count == 2
94+
95+
96+
def test_get_nested_field():
97+
column = "model.location.location.path"
98+
fields = {"name": FieldInfo(annotation=str), "model": FieldInfo(annotation=OtherTestModel)}
99+
100+
try:
101+
result = _get_nested_field(column, fields)
102+
assert result.annotation == Path
103+
except typer.Exit:
104+
pytest.fail("typer.Exit was raised unexpectedly")
105+
106+
107+
def test_get_nested_field_wrong_field():
108+
column_names = [
109+
("model.location.wrong_field", "wrong_field"),
110+
("model.wrong_path.location.path", "wrong_path"),
111+
("wrong_path.location.location.path", "wrong_path"),
112+
("model.location.path", "path"),
113+
("model.location.location.path.additional_field", "additional_field"),
114+
]
115+
fields = {"name": FieldInfo(annotation=str), "model": FieldInfo(annotation=OtherTestModel)}
116+
117+
for wrong_column, wrong_fragment in column_names:
118+
with patch("rich.console.Console.print") as mock_print:
119+
with pytest.raises(typer.Exit, match="1"):
120+
_get_nested_field(wrong_column, fields)
121+
mock_print.assert_called_once_with(f"Unknown column: {wrong_column} ({wrong_fragment} not found)")

packages/ragbits-core/tests/unit/utils/test_config_handling.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def example_factory() -> ExampleClassWithConfigMixin:
3232
return ExampleSubclass("aligator", 42)
3333

3434

35-
def test_defacult_from_config():
35+
def test_default_from_config():
3636
config = {"foo": "foo", "bar": 1}
3737
instance = ExampleClassWithConfigMixin.from_config(config)
3838
assert instance.foo == "foo"
@@ -113,7 +113,6 @@ def test_subclass_from_defaults_instance_yaml():
113113
subproject="core",
114114
current_dir=projects_dir / "project_with_instances_yaml",
115115
)
116-
print(config)
117116
instance = ExampleClassWithConfigMixin.subclass_from_defaults(config)
118117
assert isinstance(instance, ExampleSubclass)
119118
assert instance.foo == "I am a foo"

packages/ragbits-document-search/tests/cli/test_search.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,29 @@ def test_search_columns():
6969
assert "<DocumentType.TXT: 'txt'>" in result.stdout
7070

7171

72+
def test_search_nested_columns():
73+
runner = CliRunner(mix_stderr=False)
74+
result = runner.invoke(
75+
ds_app,
76+
[
77+
"--factory-path",
78+
factory_path,
79+
"search",
80+
"example query",
81+
"--columns",
82+
"location,location.coordinates,location.page_number",
83+
],
84+
)
85+
assert result.exit_code == 0
86+
print(result.stdout)
87+
assert "Foo document" not in result.stdout
88+
assert "Bar document" not in result.stdout
89+
assert "Baz document" not in result.stdout
90+
assert "Location" in result.stdout
91+
assert "Location Coordinates" in result.stdout
92+
assert "Location Page Number" in result.stdout
93+
94+
7295
def test_search_columns_non_existent():
7396
runner = CliRunner(mix_stderr=False)
7497
result = runner.invoke(
@@ -86,6 +109,23 @@ def test_search_columns_non_existent():
86109
assert "Unknown column: non_existent" in result.stderr
87110

88111

112+
def test_search_nested_columns_non_existent():
113+
runner = CliRunner(mix_stderr=False)
114+
result = runner.invoke(
115+
ds_app,
116+
[
117+
"--factory-path",
118+
factory_path,
119+
"search",
120+
"example query",
121+
"--columns",
122+
"document_meta,location,location.non_existent",
123+
],
124+
)
125+
assert result.exit_code == 1
126+
assert "Unknown column: location.non_existent" in result.stderr
127+
128+
89129
def test_search_json():
90130
autoregister()
91131
runner = CliRunner(mix_stderr=False)
@@ -101,7 +141,6 @@ def test_search_json():
101141
"example query",
102142
],
103143
)
104-
print(result.stderr)
105144
assert result.exit_code == 0
106145
elements = json.loads(result.stdout)
107146
assert len(elements) == 3

0 commit comments

Comments
 (0)