Skip to content

Commit ec0734c

Browse files
clean: Minor cleanups (#200)
1 parent aa716fe commit ec0734c

18 files changed

+92
-146
lines changed

docs/create_cli_docs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# intended to be used for anything other than a starting point.
33
# at least we would need this issue fixed first:
44
# https://github.com/explosion/radicli/issues/30
5+
from __future__ import annotations
56

67
from pathlib import Path
78

docs/create_desc_stats.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from pathlib import Path
24

35
import pandas as pd

docs/update_benchmark_tables.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
"""
2-
Script for running the benchmark and pushing the results to Datawrapper.
1+
"""Script for running the benchmark and pushing the results to Datawrapper.
32
43
Example:
5-
64
python update_benchmark_tables.py --data-wrapper-api-token <token>
75
"""
86

7+
from __future__ import annotations
8+
99
import argparse
1010
from collections import defaultdict
1111
from collections.abc import Sequence
@@ -69,7 +69,7 @@ def create_mdl_name_w_reference(mdl: seb.ModelMeta) -> str:
6969
return mdl_name
7070

7171

72-
def get_speed_results(model_meta: seb.ModelMeta) -> Optional[float]:
72+
def get_speed_results(model_meta: seb.ModelMeta) -> float | None:
7373
model = seb.get_model(model_meta.name)
7474
TOKENS_IN_UGLY_DUCKLING = 3591
7575

@@ -218,9 +218,7 @@ def push_to_datawrapper(df: pd.DataFrame, chart_id: str, token: str):
218218

219219

220220
def compute_avg_rank(df: pd.DataFrame) -> pd.Series:
221-
"""
222-
For each model in the dataset, for each task, compute the rank of the model and then compute the average rank.
223-
"""
221+
"""For each model in the dataset, for each task, compute the rank of the model and then compute the average rank."""
224222
df = df.drop(columns=["Average Score", "Open Source", "Embedding Size", "Model name", "WPS (CPU)"])
225223

226224
ranks = df.rank(axis=0, ascending=False, na_option="bottom")
@@ -229,9 +227,7 @@ def compute_avg_rank(df: pd.DataFrame) -> pd.Series:
229227

230228

231229
def compute_avg_rank_bootstrap(df: pd.DataFrame, n_samples: int = 100) -> pd.Series:
232-
"""
233-
For all models bootstrap a set of tasks and compute the average rank. Repeat this n_samples times.
234-
"""
230+
"""For all models bootstrap a set of tasks and compute the average rank. Repeat this n_samples times."""
235231
df = df.drop(columns=["Average Score", "Open Source", "Embedding Size", "Average Rank", "WPS (CPU)", "Model name"])
236232
tasks = np.array(df.columns.tolist())
237233
n_tasks = len(tasks)

pyproject.toml

Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ version = "0.13.11"
88
authors = [
99
{ name = "Kenneth Enevoldsen", email = "Kennethcenevoldsen@gmail.com" },
1010
]
11+
license = { file = "LICENSE" }
1112
description = "Scandinavian Embedding Benchmark"
1213
classifiers = [
1314
"Operating System :: POSIX :: Linux",
@@ -29,11 +30,6 @@ dependencies = [
2930
"psutil>=7.0.0",
3031
]
3132

32-
33-
[project.license]
34-
file = "LICENSE"
35-
name = "MIT"
36-
3733
[project.optional-dependencies]
3834
mistral = [
3935
"transformers>=4.31.0", # lower bound required for mistral models (could potentially be lowered)
@@ -73,21 +69,25 @@ exclude = [".*venv*"]
7369
pythonPlatform = "Darwin"
7470

7571
[tool.ruff]
76-
# extend-include = ["*.ipynb"]
7772
line-length = 150
73+
target-version = "py39"
7874

79-
# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
80-
lint.select = [
75+
[tool.ruff.lint]
76+
select = [
77+
"F", # pyflakes rules
78+
"I", # sorting for imports
79+
"E", # pycodestyle errors
80+
"D", # formatting for docs
81+
"UP", # upgrade to latest syntax if possible
82+
"FA", # Future annotations
83+
"C4", # cleaner comprehensions
8184
"A",
8285
"ANN",
8386
"ARG",
8487
"B",
85-
"C4",
8688
"COM",
8789
"D417",
88-
"E",
8990
"ERA",
90-
"F",
9191
"I",
9292
"ICN",
9393
"NPY001",
@@ -105,7 +105,7 @@ lint.select = [
105105
"SIM",
106106
"W",
107107
]
108-
lint.ignore = [
108+
ignore = [
109109
"ANN101",
110110
"ANN102",
111111
"ANN401",
@@ -116,36 +116,19 @@ lint.ignore = [
116116
"F841",
117117
"RET504",
118118
"COM812",
119+
"D100", # Missing docstring in public module
120+
"D101", # Missing docstring in public class
121+
"D102", # Missing docstring in public method
122+
"D103", # Missing docstring in public function
123+
"D105", # Missing docstring in magic method
124+
"D104", # Missing docstring in public package
125+
"D107", # Missing docstring in __init__
119126
]
120127
# Allow autofix for all enabled rules (when `--fix` is provided).
121-
lint.unfixable = ["ERA"]
122-
# Exclude a variety of commonly ignored directories.
123-
lint.exclude = [
124-
".bzr",
125-
".direnv",
126-
".eggs",
127-
".git",
128-
".hg",
129-
".nox",
130-
".pants.d",
131-
".pytype",
132-
".ruff_cache",
133-
".svn",
134-
".tox",
135-
".venv",
136-
"__pypackages__",
137-
"_build",
138-
"buck-out",
139-
"build",
140-
"dist",
141-
"node_modules",
142-
"venv",
143-
"__init__.py",
144-
"docs/conf.py",
145-
]
128+
unfixable = ["ERA"]
129+
exclude = [".venv"]
146130
# Allow unused variables when underscore-prefixed.
147-
lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
148-
target-version = "py39"
131+
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
149132

150133
[tool.ruff.lint.flake8-annotations]
151134
mypy-init-return = true
@@ -154,9 +137,8 @@ suppress-none-returning = true
154137
[tool.ruff.lint.pydocstyle]
155138
convention = "google"
156139

157-
[tool.ruff.lint.mccabe]
158-
# Unlike Flake8, default to a complexity level of 10.
159-
max-complexity = 10
140+
[tool.ruff.lint.isort]
141+
required-imports = ["from __future__ import annotations"]
160142

161143
[tool.semantic_release]
162144
branch = "main"
@@ -167,8 +149,8 @@ build_command = "python -m pip install build; python -m build"
167149
include-package-data = true
168150

169151
[tool.uv]
170-
default-groups = ["dev", "tests", "docs"]
171152
conflicts = [[{ extra = "sonar" }, { extra = "arctic" }]]
153+
default-groups = ["dev", "tests", "docs"]
172154
no-build-isolation-package = ["xformers", "flash-attn"]
173155

174156
[dependency-groups]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"task_name":"LCC","task_description":"The leipzig corpora collection, annotated for sentiment","task_version":"1.1.1","time_of_run":"2024-05-21T09:44:03.564974","scores":{"da":{"accuracy":0.3846666666666666,"f1":0.3650136884557438,"accuracy_stderr":0.03664241622309678,"f1_stderr":0.03540233062350939,"main_score":0.3846666666666666}},"main_score":"accuracy"}
1+
{"task_name":"LCC","task_description":"The leipzig corpora collection, annotated for sentiment","task_version":"1.1.1","time_of_run":"2025-05-23T16:02:19.308605","scores":{"da":{"accuracy":0.38533333333333325,"f1":0.3657168079255128,"accuracy_stderr":0.036490485822410684,"f1_stderr":0.03512881865293476,"main_score":0.38533333333333325}},"main_score":"accuracy"}

src/seb/interfaces/task.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Literal, Protocol, TypedDict, runtime_checkable
24

35
import numpy as np
@@ -36,8 +38,7 @@ class DescriptiveDatasetStats(TypedDict):
3638

3739
@runtime_checkable
3840
class Task(Protocol):
39-
"""
40-
A task is a specific evaluation task for a sentence embedding model.
41+
"""A task is a specific evaluation task for a sentence embedding model.
4142
4243
Attributes:
4344
name: The name of the task.
@@ -62,8 +63,7 @@ class Task(Protocol):
6263
description: str
6364

6465
def evaluate(self, model: Encoder) -> TaskResult:
65-
"""
66-
Evaluates a Sentence Embedding Model on the task.
66+
"""Evaluates a Sentence Embedding Model on the task.
6767
6868
Args:
6969
model: A model with the encode method implemented.
@@ -74,8 +74,7 @@ def evaluate(self, model: Encoder) -> TaskResult:
7474
...
7575

7676
def get_documents(self) -> list[str]:
77-
"""
78-
Get the documents for the task.
77+
"""Get the documents for the task.
7978
8079
Returns:
8180
A list of strings.
@@ -95,8 +94,6 @@ def get_descriptive_stats(self) -> DescriptiveDatasetStats:
9594
)
9695

9796
def name_to_path(self) -> str:
98-
"""
99-
Convert a name to a path.
100-
"""
97+
"""Convert a name to a path."""
10198
name = self.name.replace("/", "__").replace(" ", "_")
10299
return name

src/seb/registered_models/cohere_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def get_embedding_dim(self) -> int:
2727
return v.shape[1]
2828

2929
def _embed(self, sentences: list[str], input_type: str) -> torch.Tensor:
30-
import cohere
30+
import cohere # type: ignore[import]
3131

3232
client = cohere.Client()
3333
response = client.embed(

src/seb/registered_models/llm2vec_models.py

Lines changed: 9 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
from __future__ import annotations
2+
23
import logging
4+
from collections.abc import Iterable, Sequence
35
from datetime import date
46
from functools import partial
5-
import torch
6-
from typing import Any, Optional, TypeVar, Union, List
7-
from collections.abc import Iterable, Sequence
8-
from tqdm import tqdm
97
from itertools import islice
8+
from typing import Any, List, Optional, TypeVar, Union
9+
1010
import numpy as np
11+
import torch
12+
from tqdm import tqdm
1113

1214
import seb
13-
from seb.interfaces.model import LazyLoadEncoder, ModelMeta, SebModel, Encoder
15+
from seb.interfaces.model import Encoder, LazyLoadEncoder, ModelMeta, SebModel
1416
from seb.interfaces.task import Task
1517
from seb.registries import models
1618

17-
1819
logger = logging.getLogger(__name__)
1920
T = TypeVar("T")
2021

@@ -95,12 +96,12 @@ def __init__(
9596
):
9697
logger.info("Started loading LLM2Vec model")
9798
try:
98-
from llm2vec import LLM2Vec
99+
from llm2vec import LLM2Vec # type: ignore[import]
99100
except ImportError:
100101
raise ImportError("To use the LLM2Vec models `llm2vec` is required. Please install it with `pip seb[llm2vec].")
101102
extra_kwargs = {}
102103
try:
103-
import flash_attn # noqa
104+
import flash_attn # type: ignore[import]
104105

105106
extra_kwargs["attn_implementation"] = "flash_attention_2"
106107
except ImportError:
@@ -149,32 +150,6 @@ def encode(
149150
return torch.cat(batched_embeddings).numpy()
150151

151152

152-
@models.register("TTC-L2V-supervised-da-1")
153-
def create_llm2vec_da_mntp_ttc_supervised() -> SebModel:
154-
base_model = "jealk/llm2vec-da-mntp"
155-
peft_model = "jealk/TTC-L2V-supervised-1"
156-
meta = ModelMeta(
157-
name="TTC-L2V-supervised-da-1",
158-
huggingface_name=peft_model,
159-
reference=f"https://huggingface.co/{peft_model}",
160-
languages=["da"],
161-
open_source=True,
162-
embedding_size=4096,
163-
architecture="LLM2Vec",
164-
release_date=date(2024, 12, 20),
165-
)
166-
partial_model = partial(
167-
LLM2VecModel,
168-
base_model_name_or_path=base_model,
169-
peft_model_name_or_path=peft_model,
170-
max_length=8192,
171-
)
172-
return SebModel(
173-
encoder=LazyLoadEncoder(partial_model),
174-
meta=meta,
175-
)
176-
177-
178153
@models.register("TTC-L2V-unsupervised-da-1")
179154
def create_llm2vec_da_mntp_ttc_unsupervised() -> SebModel:
180155
base_model = "jealk/llm2vec-da-mntp"

tests/dummy_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24

35
import seb

tests/dummy_task.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from datetime import datetime
24
from typing import Any
35

@@ -52,9 +54,7 @@ def evaluate(self, model: seb.Encoder) -> seb.TaskResult:
5254

5355

5456
def create_test_raise_error_task() -> seb.Task:
55-
"""
56-
Note this task is not registered as it will cause errrors in other tests.
57-
"""
57+
"""Note this task is not registered as it will cause errors in other tests."""
5858

5959
class TestTaskWithError(TestTask):
6060
name = "test raise error task"

0 commit comments

Comments
 (0)