-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransform_pytorch.py
75 lines (66 loc) · 2.48 KB
/
transform_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# transform_pytorch.py
import torch
import torch.nn as nn
class TransformNet(nn.Module):
    """Image-transformation network for fast neural style transfer.

    Architecture: a downsampling encoder (3 strided convs), a bottleneck
    of 5 residual blocks, and an upsampling decoder (2 transposed convs
    plus a final conv), with a global skip connection from the input.
    Expects pixel values in [0, 255]; output is mapped back to [0, 255].
    """

    def __init__(self):
        super(TransformNet, self).__init__()
        # Downsampling encoder.
        self.conv1 = ConvLayer(3, 32, 9, 1)
        self.conv2 = ConvLayer(32, 64, 3, 2)
        self.conv3 = ConvLayer(64, 128, 3, 2)
        # Bottleneck residual blocks (spatial size unchanged).
        self.resid1 = ResidualBlock(128)
        self.resid2 = ResidualBlock(128)
        self.resid3 = ResidualBlock(128)
        self.resid4 = ResidualBlock(128)
        self.resid5 = ResidualBlock(128)
        # Upsampling decoder; final conv skips ReLU so tanh sees raw logits.
        self.conv_t1 = ConvTransposeLayer(128, 64, 3, 2)
        self.conv_t2 = ConvTransposeLayer(64, 32, 3, 2)
        self.conv_t3 = ConvLayer(32, 3, 9, 1, relu=False)

    def forward(self, x):
        stages = (
            self.conv1, self.conv2, self.conv3,
            self.resid1, self.resid2, self.resid3,
            self.resid4, self.resid5,
            self.conv_t1, self.conv_t2, self.conv_t3,
        )
        y = x / 255.0  # normalize raw pixels into [0, 1]
        for stage in stages:
            y = stage(y)
        y = torch.tanh(y)
        # NOTE(review): this skip adds the *unnormalized* input x (0..255) to a
        # tanh-bounded signal (-1..1), so the outer tanh below saturates for
        # nearly every pixel -- confirm this is intended (x / 255.0 seems more
        # likely to have been meant).
        y = x + y
        # 255. / 2 == 127.5: maps tanh's (-1, 1) output back into (0, 255).
        return torch.tanh(y) * 127.5 + 255. / 2
class ConvLayer(nn.Module):
    """Conv2d followed by InstanceNorm2d, with an optional trailing ReLU.

    Padding is kernel_size // 2 ('same'-style), so a stride-1 layer
    preserves the spatial dimensions of its input.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, relu=True):
        super(ConvLayer, self).__init__()
        self.conv2d = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, kernel_size // 2
        )
        self.instance_norm = nn.InstanceNorm2d(out_channels, affine=True)
        self.relu = relu  # whether to apply ReLU after normalization

    def forward(self, x):
        out = self.instance_norm(self.conv2d(x))
        return torch.relu(out) if self.relu else out
class ConvTransposeLayer(nn.Module):
    """ConvTranspose2d + InstanceNorm2d + ReLU, used for upsampling.

    With padding = kernel_size // 2 and output_padding = stride - 1,
    a stride-s layer multiplies each spatial dimension by exactly s.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvTransposeLayer, self).__init__()
        self.conv2d_transpose = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=kernel_size // 2,
            output_padding=stride - 1,
        )
        self.instance_norm = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x):
        return torch.relu(self.instance_norm(self.conv2d_transpose(x)))
class ResidualBlock(nn.Module):
    """Two 3x3 ConvLayers with an identity skip connection.

    The second conv omits its ReLU so the skip addition happens on
    pre-activation values, matching the standard style-transfer block.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(channels, channels, 3, 1)
        self.conv2 = ConvLayer(channels, channels, 3, 1, relu=False)

    def forward(self, x):
        return self.conv2(self.conv1(x)) + x