Skip to content

Commit ca56962

Browse files
authored
Merge pull request #371 from DrHazemAli/main
Enhance LoraManager to Support .tar Files
2 parents 9668537 + 224e848 commit ca56962

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

fooocusapi/utils/lora_manager.py

+27-15
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,12 @@
1-
"""
2-
Manager loras from url
3-
4-
@author: TechnikMax
5-
@github: https://github.com/TechnikMax
6-
"""
71
import hashlib
82
import os
93
import requests
10-
4+
import tarfile
115

126
def _hash_url(url):
137
"""Generates a hash value for a given URL."""
148
return hashlib.md5(url.encode('utf-8')).hexdigest()
159

16-
1710
class LoraManager:
1811
"""
1912
Manager loras from url
@@ -26,29 +19,48 @@ def __init__(self):
2619

2720
def _download_lora(self, url):
2821
"""
29-
Downloads a LoRa from a URL and saves it in the cache.
22+
Downloads a LoRa from a URL, saves it in the cache, and if it's a .tar file, extracts it and returns the .safetensors file.
3023
"""
3124
url_hash = _hash_url(url)
32-
filepath = os.path.join(self.cache_dir, f"{url_hash}.safetensors")
33-
file_name = f"{url_hash}.safetensors"
25+
file_ext = url.split('.')[-1]
26+
filepath = os.path.join(self.cache_dir, f"{url_hash}.{file_ext}")
3427

3528
if not os.path.exists(filepath):
36-
print(f"start download for: {url}")
29+
print(f"Start download for: {url}")
3730

3831
try:
3932
response = requests.get(url, timeout=10, stream=True)
4033
response.raise_for_status()
4134
with open(filepath, 'wb') as f:
4235
for chunk in response.iter_content(chunk_size=8192):
4336
f.write(chunk)
44-
print(f"Download successfully, saved as {file_name}")
37+
38+
if file_ext == "tar":
39+
print("Extracting the tar file...")
40+
with tarfile.open(filepath, 'r:*') as tar:
41+
tar.extractall(path=self.cache_dir)
42+
print("Extraction completed.")
43+
return self._find_safetensors_file(self.cache_dir)
4544

45+
print(f"Download successfully, saved as {filepath}")
4646
except Exception as e:
47-
raise Exception(f"error downloading {url}: {e}") from e
47+
raise Exception(f"Error downloading {url}: {e}") from e
4848

4949
else:
5050
print(f"LoRa already downloaded {url}")
51-
return file_name
51+
52+
return filepath
53+
54+
def _find_safetensors_file(self, directory):
55+
"""
56+
Finds the first .safetensors file in the specified directory.
57+
"""
58+
print("Searching for .safetensors file.")
59+
for root, dirs, files in os.walk(directory):
60+
for file in files:
61+
if file.endswith('.safetensors'):
62+
return os.path.join(root, file)
63+
raise FileNotFoundError("No .safetensors file found in the extracted files.")
5264

5365
def check(self, urls):
5466
"""Manages the specified LoRAs: downloads missing ones and returns their file names."""

0 commit comments

Comments
 (0)