diff --git a/src/so_vits_svc_fork/utils.py b/src/so_vits_svc_fork/utils.py index 6efb8321..02e63530 100644 --- a/src/so_vits_svc_fork/utils.py +++ b/src/so_vits_svc_fork/utils.py @@ -452,14 +452,18 @@ def clean_checkpoints( False -> lexicographically delete ckpts """ path_to_models = Path(path_to_models) - name_key = lambda p: int(re.match(r"._(\d+)\.pth", p.name).group(1)) + name_key = lambda p: int(re.match(r"._(\d+)", p.stem).group(1)) time_key = lambda p: p.stat().st_mtime models_sorted = sorted( - path_to_models.glob(r"._(\d+).pth"), key=time_key if sort_by_time else name_key + filter( + lambda p: (p.is_file() and re.match(r"._\d+", p.stem)), + path_to_models.glob("*.pth"), + ), + key=time_key if sort_by_time else name_key, ) - models_sorted_grouped = groupby(models_sorted, lambda p: p.name[0]) + models_sorted_grouped = groupby(models_sorted, lambda p: p.stem[0]) for k, g in models_sorted_grouped: - to_dels = list(g)[n_ckpts_to_keep:] + to_dels = list(g)[:-n_ckpts_to_keep] for to_del in to_dels: if to_del.stem.endswith("_0"): continue