Skip to content

Commit e410890

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6246a3c commit e410890

File tree

7 files changed

+9
-9
lines changed

7 files changed

+9
-9
lines changed

ai_toolkit/datasets/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def get_dataset_initializer(dataset_name: str) -> DatasetLoader:
1212
"""Retrieves class initializer from its string name."""
1313
if not hasattr(sys.modules[__name__], dataset_name):
1414
raise RuntimeError(f"Dataset class {dataset_name} not found in datasets/")
15-
return cast(DatasetLoader, getattr(sys.modules[__name__], dataset_name)())
15+
return cast("DatasetLoader", getattr(sys.modules[__name__], dataset_name)())
1616

1717

1818
__all__ = (

ai_toolkit/metrics/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def get_metric_initializer(metric_name: str) -> type[Metric]:
1515
"""Retrieves class initializer from its string name."""
1616
if not hasattr(sys.modules[__name__], metric_name):
1717
raise RuntimeError(f"Metric {metric_name} not found in metrics folder.")
18-
return cast(type[Metric], getattr(sys.modules[__name__], metric_name))
18+
return cast("type[Metric]", getattr(sys.modules[__name__], metric_name))
1919

2020

2121
__all__ = (

ai_toolkit/metrics/accuracy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
class Accuracy(Metric):
99
def __repr__(self) -> str:
10-
return f"{self.name}: {100. * self.value:.2f}%"
10+
return f"{self.name}: {100.0 * self.value:.2f}%"
1111

1212
@staticmethod
1313
def calculate_accuracy(output: torch.Tensor, target: torch.Tensor) -> float:

ai_toolkit/metrics/f1_score.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def calculate_f1_score(
3030
f1 = 2 * (precision * recall) / (precision + recall + eps)
3131
f1 = f1.clamp(min=eps, max=1 - eps)
3232
f1_score = 1 - f1.mean()
33-
return cast(float, f1_score)
33+
return cast("float", f1_score)
3434

3535
def update(self, val_dict: SimpleNamespace) -> float:
3636
y_pred, y_true = val_dict.output, val_dict.target

ai_toolkit/metrics/loss.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ def update(self, val_dict: SimpleNamespace) -> float:
1010
self.epoch_avg += loss * val_dict.batch_size
1111
self.running_avg += loss * val_dict.batch_size
1212
self.num_examples += val_dict.batch_size
13-
return cast(float, loss)
13+
return cast("float", loss)

ai_toolkit/models/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ def get_model_initializer(model_name: str) -> type[nn.Module]:
1919
"""Retrieves class initializer from its string name."""
2020
if not hasattr(sys.modules[__name__], model_name):
2121
raise RuntimeError(f"Model class {model_name} not found in models/")
22-
return cast(type[nn.Module], getattr(sys.modules[__name__], model_name))
22+
return cast("type[nn.Module]", getattr(sys.modules[__name__], model_name))
2323

2424

2525
__all__ = (
2626
"BasicCNN",
27-
"DenseNet",
2827
"BasicLSTM",
29-
"MaskRCNN",
3028
"BasicRNN",
29+
"DenseNet",
30+
"MaskRCNN",
3131
"get_model_initializer",
3232
)

ai_toolkit/test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_model(
4646
test_loss /= test_len
4747
print(
4848
f"\nTest set: Average loss: {test_loss:.4f},",
49-
f"Accuracy: {correct}/{test_len} ({100. * correct / test_len:.2f}%)\n",
49+
f"Accuracy: {correct}/{test_len} ({100.0 * correct / test_len:.2f}%)\n",
5050
)
5151

5252

0 commit comments

Comments
 (0)