-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjax_resnet.py
352 lines (310 loc) · 9.79 KB
/
jax_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union

import equinox as eqx
import equinox.nn as nn
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
from jaxtyping import Array
def _conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, key=None):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``, so with ``stride=1`` the spatial
    size of the input is preserved.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        groups: Number of channel groups.
        dilation: Dilation factor (also used as the padding amount).
        key: PRNG key used to initialise the convolution weights.

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        use_bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, key=key, **conv_options)
def _conv1x1(in_planes, out_planes, stride=1, key=None):
    """Build a bias-free 1x1 (pointwise) ``nn.Conv2d``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        key: PRNG key used to initialise the convolution weights.

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    pointwise_conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        use_bias=False,
        key=key,
    )
    return pointwise_conv
class _ResNetBasicBlock(eqx.nn.StatefulLayer):
    """Two-convolution residual block (as used by ResNet-18/34).

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``. The
    shortcut is either ``x`` itself or the ``downsample`` projection passed at
    construction time. The batch-norm ``state`` is threaded through every
    normalisation layer, including any inside ``downsample``.
    """

    expansion: int
    conv1: eqx.Module
    bn1: eqx.Module
    relu: Callable
    conv2: eqx.Module
    bn2: eqx.Module
    downsample: eqx.Module
    stride: int
    # True when ``downsample`` is a caller-supplied projection (which may hold
    # stateful norm layers and so must receive ``state``); False when it is
    # just the ``nn.Identity`` placeholder.
    has_downsample: bool

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
        key=None,
    ):
        """Create the block.

        Args:
            inplanes: Number of input channels.
            planes: Number of output channels (``expansion`` is 1).
            stride: Stride of the first 3x3 convolution.
            downsample: Optional module applied to the input to form the
                shortcut; it is called as ``downsample(x, state)``.
            groups: Must be 1.
            base_width: Must be 64.
            dilation: Must be 1.
            norm_layer: Normalisation layer constructor; defaults to
                ``nn.BatchNorm``.
            key: PRNG key used to initialise both convolutions.

        Raises:
            ValueError: If ``groups != 1`` or ``base_width != 64``.
            NotImplementedError: If ``dilation > 1``.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        keys = jrandom.split(key, 2)
        self.expansion = 1
        self.conv1 = _conv3x3(inplanes, planes, stride, key=keys[0])
        self.bn1 = norm_layer(planes, axis_name="batch")
        self.relu = jnn.relu
        self.conv2 = _conv3x3(planes, planes, key=keys[1])
        self.bn2 = norm_layer(planes, axis_name="batch")
        if downsample is not None:
            self.has_downsample = True
            self.downsample = downsample
        else:
            self.has_downsample = False
            self.downsample = nn.Identity()
        self.stride = stride

    def __call__(
        self,
        x: Array,
        state: nn.State,
        *,
        key: Optional["jax.random.PRNGKey"] = None,
    ) -> Tuple[Array, nn.State]:
        """Apply the block.

        Args:
            x: Input feature map.
            state: Batch-norm state container.
            key: Unused; accepted so the block can be driven by ``nn.Sequential``.

        Returns:
            ``(output, updated_state)``.
        """
        out = self.conv1(x)
        out, state = self.bn1(out, state)
        out = self.relu(out)
        out = self.conv2(out)
        out, state = self.bn2(out, state)
        if self.has_downsample:
            # The projection built by ResNet contains a stateful BatchNorm, so
            # it must be given (and hand back) the state.
            identity, state = self.downsample(x, state)
        else:
            identity = x
        out = out + identity
        out = self.relu(out)
        return out, state
class _ResNetBottleneck(eqx.nn.StatefulLayer):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block (as used by ResNet-50).

    Computes ``relu(bn3(conv3(relu(bn2(conv2(relu(bn1(conv1(x)))))))) +
    shortcut(x))``. The output has ``planes * expansion`` channels
    (``expansion`` is 4). The shortcut is ``x`` or the ``downsample``
    projection passed at construction time. The batch-norm ``state`` is
    threaded through every norm layer.
    """

    expansion: int
    conv1: eqx.Module
    bn1: eqx.Module
    conv2: eqx.Module
    bn2: eqx.Module
    conv3: eqx.Module
    bn3: eqx.Module
    relu: Callable
    downsample: eqx.Module
    stride: int
    # True when ``downsample`` is a caller-supplied projection that must be
    # called with ``state``; False when it is the ``nn.Identity`` placeholder.
    has_downsample: bool

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
        key=None,
    ):
        """Create the block.

        Args:
            inplanes: Number of input channels.
            planes: Base channel count; the block outputs ``planes * 4``.
            stride: Stride of the middle 3x3 convolution.
            downsample: Optional module applied to the input to form the
                shortcut; it is called as ``downsample(x, state)``.
            groups: Channel groups of the 3x3 convolution.
            base_width: Width per group; the inner width is
                ``int(planes * base_width / 64) * groups``.
            dilation: Dilation of the 3x3 convolution.
            norm_layer: Normalisation layer constructor; defaults to
                ``nn.BatchNorm``.
            key: PRNG key used to initialise the three convolutions.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm
        self.expansion = 4
        keys = jrandom.split(key, 3)
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = _conv1x1(inplanes, width, key=keys[0])
        self.bn1 = norm_layer(width, axis_name="batch")
        self.conv2 = _conv3x3(width, width, stride, groups, dilation, key=keys[1])
        self.bn2 = norm_layer(width, axis_name="batch")
        self.conv3 = _conv1x1(width, planes * self.expansion, key=keys[2])
        self.bn3 = norm_layer(planes * self.expansion, axis_name="batch")
        self.relu = jnn.relu
        if downsample is not None:
            self.has_downsample = True
            self.downsample = downsample
        else:
            self.has_downsample = False
            self.downsample = nn.Identity()
        self.stride = stride

    def __call__(
        self,
        x: Array,
        state: nn.State,
        *,
        key: Optional["jax.random.PRNGKey"] = None,
    ) -> Tuple[Array, nn.State]:
        """Apply the block.

        Args:
            x: Input feature map.
            state: Batch-norm state container.
            key: Unused; accepted so the block can be driven by ``nn.Sequential``.

        Returns:
            ``(output, updated_state)``.
        """
        out = self.conv1(x)
        out, state = self.bn1(out, state)
        out = self.relu(out)
        out = self.conv2(out)
        out, state = self.bn2(out, state)
        out = self.relu(out)
        out = self.conv3(out)
        out, state = self.bn3(out, state)
        if self.has_downsample:
            identity, state = self.downsample(x, state)
        else:
            identity = x
        out = out + identity
        out = self.relu(out)
        return out, state
# Output-channel multiplier (relative to ``planes``) for each block type. It
# matches the ``expansion`` attribute each block sets in its constructor, and
# lets ResNet size layers before any block instance exists.
EXPANSIONS = {
    _ResNetBasicBlock: 1,
    _ResNetBottleneck: 4,
}
class ResNet(eqx.Module):
    """A simple port of `torchvision.models.resnet`.

    The model is stateful: batch-norm statistics live in an ``nn.State`` that
    is passed to, and returned from, ``__call__``. The default norm layers are
    built with ``axis_name="batch"``, so presumably the model is meant to be
    run under ``jax.vmap(..., axis_name="batch")`` (TODO confirm with callers).
    """

    inplanes: int
    dilation: int
    groups: int
    base_width: int
    conv1: eqx.Module
    bn1: eqx.Module
    relu: Callable
    maxpool: eqx.Module
    layer1: eqx.Module
    layer2: eqx.Module
    layer3: eqx.Module
    layer4: eqx.Module
    avgpool: eqx.Module
    fc: eqx.Module

    def __init__(
        self,
        block: Type[Union["_ResNetBasicBlock", "_ResNetBottleneck"]],
        layers: List[int],
        num_classes: int = 1000,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Any = None,
        *,
        in_channels: int = 3,
        key: Optional["jax.random.PRNGKey"] = None,
    ):
        """Build the network.

        Args:
            block: Residual block class (basic or bottleneck).
            layers: Number of blocks in each of the four stages.
            num_classes: Size of the final linear layer's output.
            groups: Channel groups passed to every block.
            width_per_group: Base width passed to every block.
            replace_stride_with_dilation: Three flags, one per stage 2-4. A
                True flag replaces that stage's stride-2 downsampling with
                dilation. Defaults to all False.
            norm_layer: Normalisation layer constructor; defaults to
                ``nn.BatchNorm``. It is called as ``norm_layer(channels,
                axis_name="batch")`` and, for the stem,
                ``norm_layer(input_size=channels, axis_name="batch")``.
            in_channels: Number of channels of the input image (default 3).
            key: PRNG key for parameter initialisation; defaults to
                ``jax.random.PRNGKey(0)``.

        Raises:
            ValueError: If ``replace_stride_with_dilation`` does not have
                exactly three entries.
        """
        super().__init__()
        # Previously the argument was unconditionally overwritten, silently
        # ignoring a caller-supplied norm layer.
        if norm_layer is None:
            norm_layer = nn.BatchNorm
        if key is None:
            key = jrandom.PRNGKey(0)
        keys = jrandom.split(key, 6)
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            in_channels,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            use_bias=False,
            key=keys[0],
        )
        self.bn1 = norm_layer(input_size=self.inplanes, axis_name="batch")
        self.relu = jnn.relu
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer, key=keys[1])
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            norm_layer,
            stride=2,
            dilate=replace_stride_with_dilation[0],
            key=keys[2],
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            norm_layer,
            stride=2,
            dilate=replace_stride_with_dilation[1],
            key=keys[3],
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            norm_layer,
            stride=2,
            dilate=replace_stride_with_dilation[2],
            key=keys[4],
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * EXPANSIONS[block], num_classes, key=keys[5])

    def _make_layer(
        self, block, planes, blocks, norm_layer, stride=1, dilate=False, key=None
    ):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        If ``dilate`` is set, ``self.dilation`` is multiplied by ``stride`` and
        the stage uses stride 1 instead. The first block receives a 1x1 conv +
        norm projection shortcut whenever the stride or the channel count
        changes.

        Side effect: updates ``self.inplanes`` (and ``self.dilation`` when
        dilating), so stages must be built in order. This is only possible
        while ``__init__`` is running.
        """
        keys = jrandom.split(key, blocks + 1)
        downsample = None
        # The first block keeps the dilation in force before this stage.
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * EXPANSIONS[block]:
            downsample = nn.Sequential(
                [
                    _conv1x1(
                        self.inplanes, planes * EXPANSIONS[block], stride, key=keys[0]
                    ),
                    norm_layer(planes * EXPANSIONS[block], axis_name="batch"),
                ]
            )
        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
                key=keys[1],
            )
        )
        self.inplanes = planes * EXPANSIONS[block]
        for block_idx in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                    key=keys[block_idx + 1],
                )
            )
        return nn.Sequential(layers)

    def __call__(self, x: Array, state: nn.State) -> Tuple[Array, nn.State]:
        """Run the network.

        Args:
            x: Input image. This presumably is a single unbatched image of
                shape ``(in_channels, H, W)``: the pooled features are
                flattened with ``jnp.ravel`` before the linear head, which
                only works without a batch axis.
            state: Batch-norm state container.

        Returns:
            ``(logits, updated_state)``, where ``logits`` has ``num_classes``
            entries.
        """
        x = self.conv1(x)
        x, state = self.bn1(x, state)
        x = self.relu(x)
        x = self.maxpool(x)
        x, state = self.layer1(x, state)
        x, state = self.layer2(x, state)
        x, state = self.layer3(x, state)
        x, state = self.layer4(x, state)
        x = self.avgpool(x)
        x = jnp.ravel(x)
        x = self.fc(x)
        return x, state
def _resnet(block, layers, **kwargs):
    """Construct a ``ResNet`` from a block type and per-stage block counts.

    Any extra keyword arguments are forwarded unchanged to ``ResNet``.
    """
    return ResNet(block, layers, **kwargs)
def resnet18(**kwargs) -> ResNet:
    """Build ResNet-18: basic blocks with stage depths 2, 2, 2, 2.

    Keyword arguments are forwarded to ``ResNet`` (e.g. ``num_classes``, ``key``).
    """
    stage_depths = [2, 2, 2, 2]
    return _resnet(_ResNetBasicBlock, stage_depths, **kwargs)
def resnet34(**kwargs) -> ResNet:
    """Build ResNet-34: basic blocks with stage depths 3, 4, 6, 3.

    Keyword arguments are forwarded to ``ResNet`` (e.g. ``num_classes``, ``key``).
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet(_ResNetBasicBlock, stage_depths, **kwargs)
def resnet50(**kwargs) -> ResNet:
    """Build ResNet-50: bottleneck blocks with stage depths 3, 4, 6, 3.

    Keyword arguments are forwarded to ``ResNet`` (e.g. ``num_classes``, ``key``).
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet(_ResNetBottleneck, stage_depths, **kwargs)