Skip to content

perf: Optimize page loading #954

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 39 additions & 22 deletions src/tagstudio/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import shutil
import time
import unicodedata
from collections.abc import Iterator
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from datetime import UTC, datetime
from os import makedirs
Expand Down Expand Up @@ -170,26 +170,26 @@ class SearchResult:

Attributes:
total_count(int): total number of items for given query, might be different than len(items).
items(list[Entry]): for current page (size matches filter.page_size).
ids(list[int]): for current page (size matches filter.page_size).
"""

total_count: int
items: list[Entry]
ids: list[int]

def __bool__(self) -> bool:
"""Boolean evaluation for the wrapper.

:return: True if there are items in the result.
:return: True if there are ids in the result.
"""
return self.total_count > 0

def __len__(self) -> int:
"""Return the total number of items in the result."""
return len(self.items)
"""Return the total number of ids in the result."""
return len(self.ids)

def __getitem__(self, index: int) -> Entry:
"""Allow to access items via index directly on the wrapper."""
return self.items[index]
def __getitem__(self, index: int) -> int:
"""Allow to access ids via index directly on the wrapper."""
return self.ids[index]


@dataclass
Expand Down Expand Up @@ -611,7 +611,7 @@ def apply_db9_schema_changes(self, session: Session):

def apply_db9_filename_population(self, session: Session):
"""Populate the filename column introduced in DB_VERSION 9."""
for entry in self.get_entries():
for entry in self.all_entries():
session.merge(entry).filename = entry.path.name
session.commit()
logger.info("[Library][Migration] Populated filename column in entries table")
Expand Down Expand Up @@ -692,6 +692,12 @@ def get_entry_full(
entry.tags = tags
return entry

def get_entries(self, entry_ids: Iterable[int]) -> list[Entry]:
    """Load the entries with the given ids, in the order the ids were given.

    Args:
        entry_ids: Ids of the entries to load. Any iterable is accepted,
            including one-shot iterators such as generators.

    Returns:
        One Entry per requested id, in the same order as ``entry_ids``.

    Raises:
        KeyError: If a requested id has no matching row in the entries table.
    """
    # Materialize once: the ids are consumed twice (IN clause and result ordering),
    # which would silently yield an empty result for a one-shot iterator.
    ordered_ids = list(entry_ids)
    if not ordered_ids:
        # Nothing requested, so there is no need to open a session.
        return []

    with Session(self.engine) as session:
        statement = select(Entry).where(Entry.id.in_(ordered_ids))
        entries_by_id = {entry.id: entry for entry in session.scalars(statement)}

    # The query result has no guaranteed order; re-order to match the request.
    return [entries_by_id[entry_id] for entry_id in ordered_ids]

def get_entries_full(self, entry_ids: list[int] | set[int]) -> Iterator[Entry]:
"""Load entry and join with all joins and all tags."""
with Session(self.engine) as session:
Expand Down Expand Up @@ -746,12 +752,25 @@ def get_entry_full_by_path(self, path: Path) -> Entry | None:
make_transient(entry)
return entry

def get_tag_entries(
    self, tag_ids: Iterable[int], entry_ids: Iterable[int]
) -> dict[int, set[int]]:
    """Return a dict of tag_id->(entry_ids with tag_id).

    Args:
        tag_ids: Ids of the tags to look up. Any iterable is accepted,
            including one-shot iterators such as generators.
        entry_ids: Ids of the entries to restrict the lookup to.

    Returns:
        A dict with one key per requested tag id. Each value is the set of ids
        from ``entry_ids`` that have a TagEntry row for that tag; the set is
        empty when no such row exists.
    """
    # Materialize once: tag_ids is consumed twice (dict seeding and IN clause),
    # which would silently produce all-empty sets for a one-shot iterator.
    wanted_tag_ids = list(tag_ids)
    wanted_entry_ids = list(entry_ids)

    tag_entries: dict[int, set[int]] = {tag_id: set() for tag_id in wanted_tag_ids}
    if not wanted_tag_ids or not wanted_entry_ids:
        # An empty IN list can never match, so skip the database round trip.
        return tag_entries

    with Session(self.engine) as session:
        statement = select(TagEntry).where(
            and_(
                TagEntry.tag_id.in_(wanted_tag_ids),
                TagEntry.entry_id.in_(wanted_entry_ids),
            )
        )
        for tag_entry in session.scalars(statement):
            tag_entries[tag_entry.tag_id].add(tag_entry.entry_id)
    return tag_entries

@property
def entries_count(self) -> int:
with Session(self.engine) as session:
return session.scalar(select(func.count(Entry.id)))

def get_entries(self, with_joins: bool = False) -> Iterator[Entry]:
def all_entries(self, with_joins: bool = False) -> Iterator[Entry]:
"""Load entries without joins."""
with Session(self.engine) as session:
stmt = select(Entry)
Expand Down Expand Up @@ -868,7 +887,7 @@ def search_library(
assert self.engine

with Session(self.engine, expire_on_commit=False) as session:
statement = select(Entry)
statement = select(Entry.id, func.count().over())

if search.ast:
start_time = time.time()
Expand All @@ -886,13 +905,6 @@ def search_library(
elif extensions:
statement = statement.where(Entry.suffix.in_(extensions))

statement = statement.distinct(Entry.id)
start_time = time.time()
query_count = select(func.count()).select_from(statement.alias("entries"))
count_all: int = session.execute(query_count).scalar() or 0
end_time = time.time()
logger.info(f"finished counting ({format_timespan(end_time - start_time)})")

sort_on: ColumnExpressionArgument = Entry.id
match search.sorting_mode:
case SortingModeEnum.DATE_ADDED:
Expand All @@ -912,13 +924,18 @@ def search_library(
)

start_time = time.time()
items = session.scalars(statement).fetchall()
rows = session.execute(statement).fetchall()
ids = []
count = 0
for row in rows:
id, count = row._tuple()
ids.append(id)
end_time = time.time()
logger.info(f"SQL Execution finished ({format_timespan(end_time - start_time)})")

res = SearchResult(
total_count=count_all,
items=list(items),
total_count=count,
ids=ids,
)

session.expunge_all()
Expand Down
5 changes: 3 additions & 2 deletions src/tagstudio/core/utils/dupe_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,13 @@ def refresh_dupe_files(self, results_filepath: str | Path):
results = self.library.search_library(
BrowsingState.from_path(path_relative), 500
)
entries = self.library.get_entries(results.ids)

if not results:
# file not in library
continue

files.append(results[0])
files.append(entries[0])

if not len(files) > 1:
# only one file in the group, nothing to do
Expand All @@ -79,7 +80,7 @@ def merge_dupe_entries(self):
)

for i, entries in enumerate(self.groups):
remove_ids = [x.id for x in entries[1:]]
remove_ids = entries[1:]
logger.info("Removing entries group", ids=remove_ids)
self.library.remove_entries(remove_ids)
yield i - 1 # The -1 waits for the next step to finish
2 changes: 1 addition & 1 deletion src/tagstudio/core/utils/missing_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def refresh_missing_files(self) -> Iterator[int]:
"""Track the number of entries that point to an invalid filepath."""
logger.info("[refresh_missing_files] Refreshing missing files...")
self.missing_file_entries = []
for i, entry in enumerate(self.library.get_entries()):
for i, entry in enumerate(self.library.all_entries()):
full_path = self.library.library_dir / entry.path
if not full_path.exists() or not full_path.is_file():
self.missing_file_entries.append(entry)
Expand Down
4 changes: 2 additions & 2 deletions src/tagstudio/qt/modals/folders_to_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def add_tag_to_tree(items: list[Tag]):
reversed_tag = reverse_tag(library, tag, None)
add_tag_to_tree(reversed_tag)

for entry in library.get_entries():
for entry in library.all_entries():
folders = entry.path.parts[0:-1]
if not folders:
continue
Expand Down Expand Up @@ -123,7 +123,7 @@ def _add_folders_to_tree(items: Sequence[str]) -> BranchData:
reversed_tag = reverse_tag(library, tag, None)
add_tag_to_tree(reversed_tag)

for entry in library.get_entries():
for entry in library.all_entries():
folders = entry.path.parts[0:-1]
if not folders:
continue
Expand Down
28 changes: 5 additions & 23 deletions src/tagstudio/qt/ts_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,7 +1425,8 @@ def update_thumbs(self):
logger.info("[QtDriver] Loading Entries...")
# TODO: The full entries with joins don't need to be grabbed here.
# Use a method that only selects the frame content but doesn't include the joins.
entries: list[Entry] = list(self.lib.get_entries_full(self.frame_content))
entries = self.lib.get_entries(self.frame_content)
tag_entries = self.lib.get_tag_entries([TAG_ARCHIVED, TAG_FAVORITE], self.frame_content)
logger.info("[QtDriver] Building Filenames...")
filenames: list[Path] = [self.lib.library_dir / e.path for e in entries]
logger.info("[QtDriver] Done! Processing ItemThumbs...")
Expand Down Expand Up @@ -1471,27 +1472,8 @@ def update_thumbs(self):
(time.time(), filenames[index], base_size, ratio, is_loading, is_grid_thumb),
)
)
item_thumb.assign_badge(BadgeType.ARCHIVED, entry.is_archived)
item_thumb.assign_badge(BadgeType.FAVORITE, entry.is_favorite)
item_thumb.update_clickable(
clickable=(
lambda checked=False, item_id=entry.id: self.toggle_item_selection(
item_id,
append=(
QGuiApplication.keyboardModifiers()
== Qt.KeyboardModifier.ControlModifier
),
bridge=(
QGuiApplication.keyboardModifiers() == Qt.KeyboardModifier.ShiftModifier
),
)
)
)
item_thumb.delete_action.triggered.connect(
lambda checked=False, f=filenames[index], e_id=entry.id: self.delete_files_callback(
f, e_id
)
)
item_thumb.assign_badge(BadgeType.ARCHIVED, entry.id in tag_entries[TAG_ARCHIVED])
item_thumb.assign_badge(BadgeType.FAVORITE, entry.id in tag_entries[TAG_FAVORITE])

# Restore Selected Borders
is_selected = item_thumb.item_id in self.selected
Expand Down Expand Up @@ -1576,7 +1558,7 @@ def update_browsing_state(self, state: BrowsingState | None = None) -> None:
)

# update page content
self.frame_content = [item.id for item in results.items]
self.frame_content = results.ids
self.update_thumbs()

# update pagination
Expand Down
13 changes: 11 additions & 2 deletions src/tagstudio/qt/widgets/item_thumb.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import structlog
from PIL import Image, ImageQt
from PySide6.QtCore import QEvent, QMimeData, QSize, Qt, QUrl
from PySide6.QtGui import QAction, QDrag, QEnterEvent, QPixmap
from PySide6.QtGui import QAction, QDrag, QEnterEvent, QGuiApplication, QPixmap
from PySide6.QtWidgets import (
QBoxLayout,
QCheckBox,
Expand Down Expand Up @@ -321,7 +321,16 @@ def __init__(

self.base_layout.addWidget(self.thumb_container)
self.base_layout.addWidget(self.file_label)

self.thumb_button.clicked.connect(
lambda: self.driver.toggle_item_selection(
self.item_id,
append=(QGuiApplication.keyboardModifiers() == Qt.KeyboardModifier.ControlModifier),
bridge=(QGuiApplication.keyboardModifiers() == Qt.KeyboardModifier.ShiftModifier),
)
)
self.delete_action.triggered.connect(
lambda: self.driver.delete_files_callback(self.opener.filepath, self.item_id)
)
self.set_mode(mode)

@property
Expand Down
5 changes: 3 additions & 2 deletions src/tagstudio/qt/widgets/thumb_button.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,6 @@ def leaveEvent(self, event: QEvent) -> None: # noqa: N802
return super().leaveEvent(event)

def set_selected(self, value: bool) -> None: # noqa: N802
self.selected = value
self.repaint()
if value != self.selected:
self.selected = value
self.repaint()
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ def search_library() -> Library:

@pytest.fixture
def entry_min(library):
yield next(library.get_entries())
yield next(library.all_entries())


@pytest.fixture
def entry_full(library: Library):
yield next(library.get_entries(with_joins=True))
yield next(library.all_entries(with_joins=True))


@pytest.fixture
Expand Down Expand Up @@ -168,7 +168,7 @@ class Args:
driver.lib = library
# TODO - downsize this method and use it
# driver.start()
driver.frame_content = list(library.get_entries())
driver.frame_content = list(library.all_entries())
yield driver


Expand Down
2 changes: 1 addition & 1 deletion tests/macros/test_folders_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@

def test_folders_to_tags(library):
folders_to_tags(library)
entry = [x for x in library.get_entries(with_joins=True) if "bar.md" in str(x.path)][0]
entry = [x for x in library.all_entries(with_joins=True) if "bar.md" in str(x.path)][0]
assert {x.name for x in entry.tags} == {"two", "bar"}
3 changes: 2 additions & 1 deletion tests/macros/test_missing_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ def test_refresh_missing_files(library: Library):

# `bar.md` should be relinked to new correct path
results = library.search_library(BrowsingState.from_path("bar.md"), page_size=500)
assert results[0].path == Path("bar.md")
entries = library.get_entries(results.ids)
assert entries[0].path == Path("bar.md")
10 changes: 5 additions & 5 deletions tests/qt/test_field_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_add_tag_to_selection_single(qt_driver, library, entry_full):
panel.fields.add_tags_to_selected(2000)

# Then reload entry
refreshed_entry = next(library.get_entries(with_joins=True))
refreshed_entry = next(library.all_entries(with_joins=True))
assert {t.id for t in refreshed_entry.tags} == {1000, 2000}


Expand All @@ -71,13 +71,13 @@ def test_add_same_tag_to_selection_single(qt_driver, library, entry_full):
panel.fields.add_tags_to_selected(1000)

# Then reload entry
refreshed_entry = next(library.get_entries(with_joins=True))
refreshed_entry = next(library.all_entries(with_joins=True))
assert {t.id for t in refreshed_entry.tags} == {1000}


def test_add_tag_to_selection_multiple(qt_driver, library):
panel = PreviewPanel(library, qt_driver)
all_entries = library.get_entries(with_joins=True)
all_entries = library.all_entries(with_joins=True)

# We want to verify that tag 1000 is on some, but not all entries already.
tag_present_on_some: bool = False
Expand All @@ -93,15 +93,15 @@ def test_add_tag_to_selection_multiple(qt_driver, library):
assert tag_absent_on_some

# Select the multiple entries
for i, e in enumerate(library.get_entries(with_joins=True), start=0):
for i, e in enumerate(library.all_entries(with_joins=True), start=0):
qt_driver.toggle_item_selection(e.id, append=(True if i == 0 else False), bridge=False) # noqa: SIM210
panel.update_widgets()

# Add new tag
panel.fields.add_tags_to_selected(1000)

# Then reload all entries and recheck the presence of tag 1000
refreshed_entries = library.get_entries(with_joins=True)
refreshed_entries = library.all_entries(with_joins=True)
tag_present_on_some: bool = False
tag_absent_on_some: bool = False

Expand Down
2 changes: 1 addition & 1 deletion tests/qt/test_qt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@

def test_browsing_state_update(qt_driver: "QtDriver"):
# Given
for entry in qt_driver.lib.get_entries(with_joins=True):
for entry in qt_driver.lib.all_entries(with_joins=True):
thumb = ItemThumb(ItemType.ENTRY, qt_driver.lib, qt_driver, (100, 100))
qt_driver.item_thumbs.append(thumb)
qt_driver.frame_content.append(entry)
Expand Down
Loading