-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnets.py
114 lines (90 loc) · 4.33 KB
/
nets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 19 13:45:49 2021
@author: Shukla
"""
import torch
import torch.nn as nn
class autoencoder_new(nn.Module):
    """Label-conditioned convolutional autoencoder / generator.

    The encoder compresses a 3-channel ``image_size`` x ``image_size`` image
    into an ``(Out*4) x 2 x 2`` feature map (ending in a Sigmoid). In
    ``train_mode="gan"`` a flat feature vector of length ``feature_size`` is
    instead projected to the same shape by ``feat_scaler``. The label vector
    is mapped to 4 values, reshaped into one extra ``1 x 2 x 2`` channel and
    concatenated to the features. The decoder then upsamples the result back
    to a ``3 x image_size x image_size`` output squashed by Tanh.

    Args:
        Out: base channel width; the bottleneck has ``Out*4`` channels.
        feature_size: length of the feature vector consumed in ``"gan"`` mode.
        label_size: length of the label vector passed to ``forward``.
        image_size: input/output resolution; only 32 and 64 are supported.

    Raises:
        ValueError: if ``image_size`` is not 32 or 64.
    """

    def __init__(self, Out, feature_size=40, label_size=10, image_size=32):
        super().__init__()
        # Fail fast: only these resolutions reduce to the 2x2 bottleneck the
        # layers below are wired for. Previously any other value surfaced as
        # a confusing AttributeError on ``self.encoder``.
        if image_size not in (32, 64):
            raise ValueError(f"image_size must be 32 or 64, got {image_size!r}")
        self.label_size = label_size
        self.feature_size = feature_size
        self.image_size = image_size
        self.Out = Out

        # "gan" path: project a flat feature vector to (Out*4)*2*2 values,
        # reshaped to (Out*4)x2x2 in forward().
        self.feat_scaler = nn.Sequential(
            nn.Linear(self.feature_size, (self.Out * 4) * 2 * 2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Label vector -> 4 values in (0, 1); reshaped in forward() into a
        # single channel matching the bottleneck's 2x2 spatial size.
        self.class_labels = nn.Sequential(nn.Linear(self.label_size, 4), nn.Sigmoid())

        # Encoder stem: bring the input down to Out x 16 x 16.
        # NOTE: module names/indices are kept identical to the original
        # layout so existing checkpoints (state_dict keys) still load.
        if self.image_size == 32:
            self.encoder = nn.Sequential(
                # 3x32x32
                nn.Conv2d(3, self.Out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(self.Out),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:  # image_size == 64 (validated above)
            self.encoder = nn.Sequential(
                # 3x64x64
                nn.Conv2d(3, self.Out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(self.Out),
                nn.LeakyReLU(0.2, inplace=True),
                # self.Outx32x32
                nn.Conv2d(self.Out, self.Out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(self.Out),
                nn.LeakyReLU(0.2, inplace=True),
            )
        # Shared encoder body: Out x 16 x 16 -> (Out*4) x 2 x 2.
        self.encoder.add_module("Body", nn.Sequential(
            # self.Outx16x16
            nn.Conv2d(self.Out, self.Out * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.Out * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # self.Out*2x8x8
            nn.Conv2d(self.Out * 2, self.Out * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.Out * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # self.Out*4x4x4
            nn.Conv2d(self.Out * 4, self.Out * 4, 4, 2, 1, bias=False),
            nn.Sigmoid(),
        ))
        # self.Out*4x2x2

        # Decoder: (Out*4 + 1 label channel) x 2 x 2 -> Out x 32 x 32.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d((self.Out * 4) + 1, self.Out * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.Out * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # self.Out*4 x4x4
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(self.Out * 4, self.Out * 4, 3, 1, 1, bias=False),
            nn.BatchNorm2d(self.Out * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # self.Out*4x8x8
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(self.Out * 4, self.Out * 2, 3, 1, 1, bias=False),
            nn.BatchNorm2d(self.Out * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # self.Out*2x16x16
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(self.Out * 2, self.Out, 3, 1, 1, bias=False),
            nn.BatchNorm2d(self.Out),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # self.Outx32x32
        if self.image_size == 32:
            self.decoder.add_module("final", nn.Sequential(
                nn.Conv2d(self.Out, 3, 3, 1, 1, bias=False),
                nn.Tanh(),
            ))
            # 3x32x32
        else:  # image_size == 64: one more 2x upsample to reach 64x64
            self.decoder.add_module("final", nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(self.Out, self.Out, 3, 1, 1, bias=False),
                nn.BatchNorm2d(self.Out),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(self.Out, 3, 3, 1, 1, bias=False),
                nn.Tanh(),
            ))
            # 3x64x64

    def forward(self, inputs, labels, mode="normal", train_mode="ae"):
        """Decode an image from encoded (or projected) features plus labels.

        Args:
            inputs: image batch ``(B, 3, image_size, image_size)`` when
                ``train_mode="ae"``; feature batch ``(B, feature_size)`` when
                ``train_mode="gan"``.
            labels: ``(B, label_size)`` label tensor.
            mode: ``"normal"`` applies a softmax over dim 1 to ``labels``;
                any other value (e.g. ``"one_hot"``) uses ``labels`` as given.
            train_mode: ``"ae"`` encodes ``inputs`` with the conv encoder;
                ``"gan"`` projects them with ``feat_scaler``.

        Returns:
            Tensor of shape ``(B, 3, image_size, image_size)`` in ``[-1, 1]``.

        Raises:
            ValueError: if ``train_mode`` is not ``"ae"`` or ``"gan"``.
        """
        if train_mode == "ae":
            feat = self.encoder(inputs)
        elif train_mode == "gan":
            feat = self.feat_scaler(inputs).reshape(-1, self.Out * 4, 2, 2)
        else:
            # Previously this fell through and crashed with UnboundLocalError.
            raise ValueError(
                f"train_mode must be 'ae' or 'gan', got {train_mode!r}")
        if mode == "normal":
            labels = torch.softmax(labels, dim=1)
        # Any other mode (e.g. "one_hot") feeds the labels through unchanged.
        # class_labels yields 4 values = 2*2, matching the bottleneck's
        # spatial size, so they form one extra channel.
        labs = self.class_labels(labels).reshape(-1, 1, feat.shape[2], feat.shape[3])
        concat = torch.cat((feat, labs), 1)
        return self.decoder(concat)