This repository has been archived by the owner on Jun 23, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path train_advGAN_model.py
63 lines (55 loc) · 2.7 KB
/
train_advGAN_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Modified from https://github.com/mathcbc/advGAN_pytorch/blob/master/main.py
import torch
from advGAN import AdvGAN
from target_models import MNIST_target_net
import utils
if __name__ == "__main__":
    # Train an AdvGAN against a pretrained MNIST target classifier, then plot
    # the loss curves recorded during training.
    training_parameters = {
        "EPOCHS": 60,
        "BATCH_SIZE": 128,
        "LEARNING_RATE": 0.001,
    }
    targeted_model_file_name = './models/MNIST_target_model.pth'
    model_num_labels = 10  # number of output classes of the target model
    image_nc = 1           # number of image channels passed to AdvGAN
    # Bounds passed to AdvGAN; presumably the valid pixel range used to clamp
    # adversarial images -- confirm against AdvGAN's implementation.
    BOX_MIN = 0
    BOX_MAX = 1
    model_path = './models/'

    # Define what device we are using
    device = utils.define_device()

    # Load the pretrained targeted model.
    # map_location remaps tensors saved on another device (e.g. a CUDA
    # checkpoint opened on a CPU-only machine) onto the selected device;
    # without it torch.load tries to restore them on their original device.
    # NOTE(review): on torch >= 1.13 consider weights_only=True when the
    # checkpoint file is not fully trusted.
    targeted_model = MNIST_target_net().to(device)
    targeted_model.load_state_dict(
        torch.load(targeted_model_file_name, map_location=device))
    targeted_model.eval()

    # Load MNIST train dataset (the second return value, a sample count,
    # is not needed here).
    train_dataloader, _ = utils.load_mnist(
        is_train=True, batch_size=training_parameters["BATCH_SIZE"], shuffle=True)

    # Train the AdvGAN model
    advGAN = AdvGAN(device, targeted_model, model_num_labels, image_nc,
                    BOX_MIN, BOX_MAX, training_parameters["LEARNING_RATE"],
                    model_path=model_path)
    # history is indexed below by "counter", "disc_losses", "gen_losses",
    # "perturb_losses" and "adv_losses"; presumably each is a per-step
    # sequence of equal length -- confirm against AdvGAN.train.
    history = advGAN.train(train_dataloader, training_parameters["EPOCHS"])

    # Plots: all four losses together, then the GAN pair, then the
    # perturbation and adversarial losses individually.
    utils.plot_performance(history["counter"],
                           data=[history["disc_losses"], history["gen_losses"],
                                 history["perturb_losses"], history["adv_losses"]],
                           plt_names=["discriminator's mse loss", "generator's mse loss",
                                      "perturbation's loss", "adversarial's loss"],
                           fig_name="GAN_model_performance",
                           y_name="loss")
    utils.plot_performance(history["counter"],
                           data=[history["disc_losses"], history["gen_losses"]],
                           plt_names=["discriminator's mse loss", "generator's mse loss"],
                           fig_name="discriminator_generator_GAN_model_performance",
                           y_name="mse loss")
    utils.plot_performance(history["counter"],
                           data=[history["perturb_losses"]],
                           plt_names=["perturbation's loss"],
                           fig_name="perturbation_GAN_model_performance",
                           y_name="perturbation's loss",
                           colors=['mediumvioletred'])
    utils.plot_performance(history["counter"],
                           data=[history["adv_losses"]],
                           plt_names=["adversarial's loss"],
                           fig_name="adversarial_GAN_model_performance",
                           y_name="adversarial's loss",
                           colors=['crimson'])