1
- """
2
- Manager loras from url
3
-
4
- @author: TechnikMax
5
- @github: https://github.com/TechnikMax
6
- """
7
1
import hashlib
8
2
import os
9
3
import requests
10
-
4
+ import tarfile
11
5
12
6
def _hash_url (url ):
13
7
"""Generates a hash value for a given URL."""
14
8
return hashlib .md5 (url .encode ('utf-8' )).hexdigest ()
15
9
16
-
17
10
class LoraManager :
18
11
"""
19
12
Manager loras from url
@@ -26,29 +19,48 @@ def __init__(self):
26
19
27
20
def _download_lora (self , url ):
28
21
"""
29
- Downloads a LoRa from a URL and saves it in the cache.
22
+ Downloads a LoRa from a URL, saves it in the cache, and if it's a .tar file, extracts it and returns the .safetensors file .
30
23
"""
31
24
url_hash = _hash_url (url )
32
- filepath = os . path . join ( self . cache_dir , f" { url_hash } .safetensors" )
33
- file_name = f"{ url_hash } .safetensors"
25
+ file_ext = url . split ( '.' )[ - 1 ]
26
+ filepath = os . path . join ( self . cache_dir , f"{ url_hash } .{ file_ext } " )
34
27
35
28
if not os .path .exists (filepath ):
36
- print (f"start download for: { url } " )
29
+ print (f"Start download for: { url } " )
37
30
38
31
try :
39
32
response = requests .get (url , timeout = 10 , stream = True )
40
33
response .raise_for_status ()
41
34
with open (filepath , 'wb' ) as f :
42
35
for chunk in response .iter_content (chunk_size = 8192 ):
43
36
f .write (chunk )
44
- print (f"Download successfully, saved as { file_name } " )
37
+
38
+ if file_ext == "tar" :
39
+ print ("Extracting the tar file..." )
40
+ with tarfile .open (filepath , 'r:*' ) as tar :
41
+ tar .extractall (path = self .cache_dir )
42
+ print ("Extraction completed." )
43
+ return self ._find_safetensors_file (self .cache_dir )
45
44
45
+ print (f"Download successfully, saved as { filepath } " )
46
46
except Exception as e :
47
- raise Exception (f"error downloading { url } : { e } " ) from e
47
+ raise Exception (f"Error downloading { url } : { e } " ) from e
48
48
49
49
else :
50
50
print (f"LoRa already downloaded { url } " )
51
- return file_name
51
+
52
+ return filepath
53
+
54
+ def _find_safetensors_file (self , directory ):
55
+ """
56
+ Finds the first .safetensors file in the specified directory.
57
+ """
58
+ print ("Searching for .safetensors file." )
59
+ for root , dirs , files in os .walk (directory ):
60
+ for file in files :
61
+ if file .endswith ('.safetensors' ):
62
+ return os .path .join (root , file )
63
+ raise FileNotFoundError ("No .safetensors file found in the extracted files." )
52
64
53
65
def check (self , urls ):
54
66
"""Manages the specified LoRAs: downloads missing ones and returns their file names."""
0 commit comments