Skip to content

Commit 36386f1

Browse files
authored
Merge pull request #319 from CrudeDiatribe/no-pickle-importer
No pickle importer
2 parents 4e392fe + 36f81a5 commit 36386f1

File tree

2 files changed

+344
-45
lines changed

2 files changed

+344
-45
lines changed

backends/model_converter/convert_model.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,57 @@
1-
from fake_torch import fake_torch_load_zipped
21
import json
32
import numpy as np
43
from constants import SD_SHAPES, _ALPHAS_CUMPROD
5-
import sys
4+
import sys, getopt
65

76
# python convert_model.py "/Users/divamgupta/Downloads/hollie-mengert.ckpt" "/Users/divamgupta/Downloads/hollie-mengert.tdict"
87

98
# pyinstaller convert_model.py --onefile --noconfirm --clean # build using intel machine so that its cross platform lol
109

11-
checkpoint_filename = sys.argv[1]
12-
out_filename = sys.argv[2]
10+
unpickle = False
11+
12+
try:
13+
optlist, args = getopt.getopt(sys.argv[1:], "hu", ["help", "unpickle"])
14+
except getopt.GetoptError as err:
15+
print(err)
16+
#usage()
17+
sys.exit(2)
18+
for o, a in optlist:
19+
if o in ("-h", "--help"):
20+
usage()
21+
sys.exit()
22+
elif o in ("-u", "--unpickle"):
23+
unpickle = True
24+
else:
25+
assert False, "unhandled option"
26+
27+
def usage():
28+
print("\nConverts .cpkt model files into .tdict model files for Diffusion Bee")
29+
print("\npython3 convert_py [--unpickle] input.ckpt output.tdict")
30+
print("\tNormal use.")
31+
print("\n\t--unpickle")
32+
print("\t\tWill use unpickling to extract the model, please use with caution as malicious code")
33+
print("\t\tcan be hidden in the .ckpt file, executed by unpickling. Without this option, the pickle")
34+
print("\t\tinside the .ckpt will instead be decompiled and the weights extracted from that with")
35+
print("\t\tno arbitrary code execution.")
36+
print("\n\tPlease report any errors on the Diffusion Bee GitHub project or the official Discord server.")
37+
print("\npython3 convert_py --help")
38+
print("\tDisplays this message")
39+
40+
if len(args) != 2:
41+
print("Incorrect number of arguments")
42+
usage()
43+
sys.exit(2)
44+
45+
checkpoint_filename = args[0]
46+
out_filename = args[1]
47+
48+
if unpickle:
49+
from fake_torch import fake_torch_load_zipped
50+
torch_weights = fake_torch_load_zipped(open(checkpoint_filename, "rb"))
51+
else:
52+
from fake_torch import extract_weights_from_checkpoint
53+
torch_weights = extract_weights_from_checkpoint(open(checkpoint_filename, "rb"))
54+
1355

1456
#TODO add MD5s
1557

@@ -18,7 +60,6 @@
1860

1961
s = 24
2062

21-
torch_weights = fake_torch_load_zipped(open(checkpoint_filename, "rb"))
2263
keys_info = {}
2364
out_file = open( out_filename , "wb")
2465

0 commit comments

Comments
 (0)