Skip to content

Commit a783234

Browse files
committed
Initial import
1 parent ecfc96c commit a783234

File tree

11 files changed

+689
-2
lines changed

11 files changed

+689
-2
lines changed

.dockerignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
__pycache__
2+
*.py[cod]
3+
4+
*.pth
5+
*.pb
6+
*.pkl

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
__pycache__
2+
*.py[cod]
3+
4+
*.pth
5+
*.pb
6+
*.pkl

AUTHORS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Daniel J. Hofmann
2+
Harsimrat Sandhawalia

Dockerfile

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
FROM ubuntu:18.04
2+
3+
WORKDIR /usr/src/app
4+
5+
ENV LANG="C.UTF-8" LC_ALL="C.UTF-8" PATH="/opt/venv/bin:$PATH" PIP_NO_CACHE_DIR="false"
6+
7+
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
8+
python3 python3-pip python3-venv libglib2.0-0 && \
9+
rm -rf /var/lib/apt/lists/*
10+
11+
COPY requirements.txt .
12+
13+
RUN python3 -m venv /opt/venv && \
14+
python3 -m pip install pip==19.2.3 pip-tools==4.0.0 && \
15+
python3 -m piptools sync
16+
17+
COPY . .

LICENSE renamed to LICENSE.md

File renamed without changes.

Makefile

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
dockerimage ?= das/vmz
2+
dockerfile ?= Dockerfile
3+
srcdir ?= $(shell pwd)
4+
datadir ?= $(shell pwd)
5+
6+
install:
7+
@docker build -t $(dockerimage) -f $(dockerfile) .
8+
9+
i: install
10+
11+
12+
update:
13+
@docker build -t $(dockerimage) -f $(dockerfile) . --pull --no-cache
14+
15+
u: update
16+
17+
18+
run:
19+
@docker run -it --rm -v $(srcdir):/usr/src/app/ \
20+
-v $(datadir):/data \
21+
--entrypoint=/bin/bash $(dockerimage)
22+
23+
r: run
24+
25+
26+
webcam:
27+
@docker run -it --rm -v $(srcdir):/usr/src/app/ \
28+
-v $(datadir):/data \
29+
--device=/dev/video0 \
30+
--entrypoint=/bin/bash $(dockerimage)
31+
32+
w: webcam
33+
34+
35+
.PHONY: install i run r update u webcam w

README.md

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,37 @@
1-
# video-resnet
2-
ResNet 3D Conv Video models
1+
# IG65-M PyTorch
2+
3+
Unofficial PyTorch (and ONNX) models and weights for IG65-M pre-trained 3d video architectures.
4+
5+
The official research Caffe2 model and weights are availabe at: https://github.com/facebookresearch/vmz
6+
7+
8+
## Models
9+
10+
| Model | Weights | Input Size | pth | onnx |
11+
| ------------- | ------------------ | ---------- | ----------------------------------------------- | --------------------------------------------- |
12+
| r(2+1)d 34 | IG65-M | 8x112x112 | *r2plus1d_34_clip8_ig65m_from_scratch.pth* | *r2plus1d_34_clip8_ig65m_from_scratch.pb* |
13+
| r(2+1)d 34 | IG65-M + Kinetics | 8x112x112 | *r2plus1d_34_clip8_ft_kinetics_from_ig65m.pth* | *r2plus1d_34_clip8_ft_kinetics_from_ig65m.pb* |
14+
| r(2+1)d 34 | IG65-M | 32x112x112 | NA | NA |
15+
| r(2+1)d 34 | IG65-M + Kinetics | 32x112x112 | *r2plus1d_34_clip32_ft_kinetics_from_ig65m.pth* | r2plus1d_34_clip32_ft_kinetics_from_ig65m.pb |
16+
17+
18+
## Usage
19+
20+
See
21+
- `convert.py` for model conversion
22+
- `extract.py` for feature extraction
23+
24+
We provide converted `.pth` PyTorch weights as artifacts in our Github releases.
25+
26+
27+
## References
28+
- [VMZ: Model Zoo for Video Modeling](https://github.com/facebookresearch/vmz)
29+
- [Kinetics](https://arxiv.org/abs/1705.06950)
30+
- [IG65-M](https://arxiv.org/abs/1905.00561)
31+
32+
33+
## License
34+
35+
Copyright © 2019 MoabitCoin
36+
37+
Distributed under the MIT License (MIT).

convert.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
#!/usr/bin/env python3
2+
3+
import pickle
4+
import argparse
5+
from pathlib import Path
6+
7+
import torch
8+
import torch.nn as nn
9+
10+
from torchvision.models.video.resnet import VideoResNet, BasicBlock, R2Plus1dStem, Conv2Plus1D
11+
12+
13+
def r2plus1d_34(num_classes, pretrained=False, progress=False, **kwargs):
14+
model = VideoResNet(block=BasicBlock,
15+
conv_makers=[Conv2Plus1D] * 4,
16+
layers=[3, 4, 6, 3],
17+
stem=R2Plus1dStem)
18+
19+
model.fc = nn.Linear(model.fc.in_features, out_features=num_classes)
20+
21+
# Fix difference in PyTorch vs Caffe2 architecture
22+
# https://github.com/facebookresearch/VMZ/issues/89
23+
model.layer2[0].conv2[0] = Conv2Plus1D(128, 128, 288)
24+
model.layer3[0].conv2[0] = Conv2Plus1D(256, 256, 576)
25+
model.layer4[0].conv2[0] = Conv2Plus1D(512, 512, 1152)
26+
27+
# We need exact Caffe2 momentum for BatchNorm scaling
28+
for m in model.modules():
29+
if isinstance(m, nn.BatchNorm3d):
30+
m.eps = 1e-3
31+
m.momentum = 0.9
32+
33+
return model
34+
35+
36+
def blobs_from_pkl(path):
37+
with path.open(mode="rb") as f:
38+
pkl = pickle.load(f, encoding="latin1")
39+
return pkl["blobs"]
40+
41+
42+
def copy_tensor(data, blobs, name):
43+
tensor = torch.from_numpy(blobs[name])
44+
45+
del blobs[name] # enforce: use at most once
46+
47+
assert data.size() == tensor.size()
48+
assert data.dtype == tensor.dtype
49+
50+
data.copy_(tensor)
51+
52+
53+
def copy_conv(module, blobs, prefix):
54+
assert isinstance(module, nn.Conv3d)
55+
assert module.bias is None
56+
copy_tensor(module.weight.data, blobs, prefix + "_w")
57+
58+
59+
def copy_bn(module, blobs, prefix):
60+
assert isinstance(module, nn.BatchNorm3d)
61+
copy_tensor(module.weight.data, blobs, prefix + "_s")
62+
copy_tensor(module.running_mean.data, blobs, prefix + "_rm")
63+
copy_tensor(module.running_var.data, blobs, prefix + "_riv")
64+
copy_tensor(module.bias.data, blobs, prefix + "_b")
65+
66+
67+
def copy_fc(module, blobs):
68+
assert isinstance(module, nn.Linear)
69+
n = module.out_features
70+
copy_tensor(module.bias.data, blobs, "last_out_L" + str(n) + "_b")
71+
copy_tensor(module.weight.data, blobs, "last_out_L" + str(n) + "_w")
72+
73+
74+
# https://github.com/pytorch/vision/blob/v0.4.0/torchvision/models/video/resnet.py#L174-L188
75+
# https://github.com/facebookresearch/VMZ/blob/6c925c47b7d6545b64094a083f111258b37cbeca/lib/models/r3d_model.py#L233-L275
76+
def copy_stem(module, blobs):
77+
assert isinstance(module, R2Plus1dStem)
78+
assert len(module) == 6
79+
copy_conv(module[0], blobs, "conv1_middle")
80+
copy_bn(module[1], blobs, "conv1_middle_spatbn_relu")
81+
assert isinstance(module[2], nn.ReLU)
82+
copy_conv(module[3], blobs, "conv1")
83+
copy_bn(module[4], blobs, "conv1_spatbn_relu")
84+
assert isinstance(module[5], nn.ReLU)
85+
86+
87+
# https://github.com/pytorch/vision/blob/v0.4.0/torchvision/models/video/resnet.py#L82-L114
88+
def copy_conv2plus1d(module, blobs, i, j):
89+
assert isinstance(module, Conv2Plus1D)
90+
assert len(module) == 4
91+
copy_conv(module[0], blobs, "comp_" + str(i) + "_conv_" + str(j) + "_middle")
92+
copy_bn(module[1], blobs, "comp_" + str(i) + "_spatbn_" + str(j) + "_middle")
93+
assert isinstance(module[2], nn.ReLU)
94+
copy_conv(module[3], blobs, "comp_" + str(i) + "_conv_" + str(j))
95+
96+
97+
# https://github.com/pytorch/vision/blob/v0.4.0/torchvision/models/video/resnet.py#L82-L114
98+
def copy_basicblock(module, blobs, i):
99+
assert isinstance(module, BasicBlock)
100+
101+
assert len(module.conv1) == 3
102+
assert isinstance(module.conv1[0], Conv2Plus1D)
103+
copy_conv2plus1d(module.conv1[0], blobs, i, 1)
104+
assert isinstance(module.conv1[1], nn.BatchNorm3d)
105+
copy_bn(module.conv1[1], blobs, "comp_" + str(i) + "_spatbn_" + str(1))
106+
assert isinstance(module.conv1[2], nn.ReLU)
107+
108+
assert len(module.conv2) == 2
109+
assert isinstance(module.conv2[0], Conv2Plus1D)
110+
copy_conv2plus1d(module.conv2[0], blobs, i, 2)
111+
assert isinstance(module.conv2[1], nn.BatchNorm3d)
112+
copy_bn(module.conv2[1], blobs, "comp_" + str(i) + "_spatbn_" + str(2))
113+
114+
if module.downsample is not None:
115+
assert i in [3, 7, 13]
116+
assert len(module.downsample) == 2
117+
assert isinstance(module.downsample[0], nn.Conv3d)
118+
assert isinstance(module.downsample[1], nn.BatchNorm3d)
119+
copy_conv(module.downsample[0], blobs, "shortcut_projection_" + str(i))
120+
copy_bn(module.downsample[1], blobs, "shortcut_projection_" + str(i) + "_spatbn")
121+
122+
123+
def copy_layer(module, blobs, i):
124+
assert {0: 3, 3: 4, 7: 6, 13: 3}[i] == len(module)
125+
126+
for basicblock in module:
127+
copy_basicblock(basicblock, blobs, i)
128+
i += 1
129+
130+
131+
def init_canary(model):
132+
nan = float("nan")
133+
134+
for m in model.modules():
135+
if isinstance(m, nn.Conv3d):
136+
assert m.bias is None
137+
nn.init.constant_(m.weight, nan)
138+
elif isinstance(m, nn.BatchNorm3d):
139+
nn.init.constant_(m.weight, nan)
140+
nn.init.constant_(m.running_mean, nan)
141+
nn.init.constant_(m.running_var, nan)
142+
nn.init.constant_(m.bias, nan)
143+
elif isinstance(m, nn.Linear):
144+
nn.init.constant_(m.weight, nan)
145+
nn.init.constant_(m.bias, nan)
146+
147+
148+
def check_canary(model):
149+
for m in model.modules():
150+
if isinstance(m, nn.Conv3d):
151+
assert m.bias is None
152+
assert not torch.isnan(m.weight).any()
153+
elif isinstance(m, nn.BatchNorm3d):
154+
assert not torch.isnan(m.weight).any()
155+
assert not torch.isnan(m.running_mean).any()
156+
assert not torch.isnan(m.running_var).any()
157+
assert not torch.isnan(m.bias).any()
158+
elif isinstance(m, nn.Linear):
159+
assert not torch.isnan(m.weight).any()
160+
assert not torch.isnan(m.bias).any()
161+
162+
163+
def main(args):
164+
blobs = blobs_from_pkl(args.pkl)
165+
166+
model = r2plus1d_34(num_classes=args.classes)
167+
168+
init_canary(model)
169+
170+
copy_stem(model.stem, blobs)
171+
172+
layers = [model.layer1, model.layer2, model.layer3, model.layer4]
173+
blocks = [0, 3, 7, 13]
174+
175+
for layer, i in zip(layers, blocks):
176+
copy_layer(layer, blobs, i)
177+
178+
copy_fc(model.fc, blobs)
179+
180+
assert not blobs
181+
check_canary(model)
182+
183+
# Export to pytorch .pth and self-contained onnx .pb files
184+
185+
batch = torch.rand(1, 3, args.frames, 112, 112) # NxCxTxHxW
186+
torch.save(model.state_dict(), args.out.with_suffix(".pth"))
187+
torch.onnx.export(model, batch, args.out.with_suffix(".pb"))
188+
189+
# Check pth roundtrip into fresh model
190+
191+
model = r2plus1d_34(num_classes=args.classes)
192+
model.load_state_dict(torch.load(args.out.with_suffix(".pth")))
193+
194+
195+
if __name__ == "__main__":
196+
parser = argparse.ArgumentParser()
197+
arg = parser.add_argument
198+
199+
arg("pkl", type=Path, help=".pkl file to read the R(2+1)D 34 layer weights from")
200+
arg("out", type=Path, help="prefix to save converted R(2+1)D 34 layer weights to")
201+
arg("--frames", type=int, choices=(8, 32), required=True, help="clip frames for video model")
202+
arg("--classes", type=int, choices=(400, 487), required=True, help="classes in last layer")
203+
204+
main(parser.parse_args())

0 commit comments

Comments
 (0)