-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathexample_12.ml
142 lines (124 loc) · 4.55 KB
/
example_12.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)
open Owl_symbolic_neural_graph
open Owl_symbolic_types
(** InceptionV3 *)
(* Note: specify the padding of avg_pool2d etc. explicitly rather than relying on its default value. *)
(* [conv2d_bn ?padding kernel stride nn] stacks three layers on top of [nn]:
   a [conv2d] with the given [kernel], [stride] and [padding] (default
   [SAME_UPPER]), then [normalisation] (presumably batch norm, hence the
   [_bn] suffix), then a [Relu] activation. The node for the activation is
   returned. *)
let conv2d_bn ?(padding = SAME_UPPER) kernel stride nn =
  let convolved = conv2d ~padding kernel stride nn in
  let normalised = normalisation convolved in
  activation Relu normalised
(* [mix_typ1 in_ch pool_ch nn] appends an Inception block made of four
   parallel branches, each reading [nn]:
   - 1x1 convolution, 64 filters;
   - 1x1 (48 filters) -> 5x5 (64 filters);
   - 1x1 (64 filters) -> 3x3 (96) -> 3x3 (96);
   - 3x3 average pool (stride 1, SAME_UPPER) -> 1x1 with [pool_ch] filters.
   Every convolution uses stride 1 and conv2d_bn's default padding.
   [in_ch] is the input-channel dimension of each kernel that reads [nn],
   so it has to match the channel count of [nn]. The branches are
   concatenated on axis 1 in the order above, giving
   64 + 64 + 96 + pool_ch output channels. *)
let mix_typ1 in_ch pool_ch nn =
  let conv_s1 kernel x = conv2d_bn kernel [| 1; 1 |] x in
  (* Thread [x] through [kernels] left to right, one conv_s1 per kernel. *)
  let rec chain x = function
    | [] -> x
    | kernel :: rest -> chain (conv_s1 kernel x) rest
  in
  let one_by_one = conv_s1 [| 64; in_ch; 1; 1 |] nn in
  let five_by_five = chain nn [ [| 48; in_ch; 1; 1 |]; [| 64; 48; 5; 5 |] ] in
  let double_3x3 =
    chain nn [ [| 64; in_ch; 1; 1 |]; [| 96; 64; 3; 3 |]; [| 96; 96; 3; 3 |] ]
  in
  let pooled =
    avg_pool2d ~padding:SAME_UPPER [| 3; 3 |] [| 1; 1 |] nn
    |> conv_s1 [| pool_ch; in_ch; 1; 1 |]
  in
  concat ~axis:1 [| one_by_one; five_by_five; double_3x3; pooled |]
(* [mix_typ3 nn] appends a block whose three branches all end in a stride-2,
   VALID-padded layer, shrinking the spatial size of [nn]. The hard-coded
   kernels assume [nn] has 288 channels.
   - 3x3 convolution, 384 filters, stride 2, VALID;
   - 1x1 (64) -> 3x3 (96) -> 3x3 (96, stride 2, VALID);
   - 3x3 max pool, stride 2, VALID (keeps the input channels).
   Concatenated on axis 1: 384 + 96 + 288 = 768 channels when [nn] has 288. *)
let mix_typ3 nn =
  let strided_3x3 = conv2d_bn ~padding:VALID [| 384; 288; 3; 3 |] [| 2; 2 |] nn in
  let strided_dbl =
    let t = conv2d_bn [| 64; 288; 1; 1 |] [| 1; 1 |] nn in
    let t = conv2d_bn [| 96; 64; 3; 3 |] [| 1; 1 |] t in
    conv2d_bn ~padding:VALID [| 96; 96; 3; 3 |] [| 2; 2 |] t
  in
  let strided_pool = max_pool2d ~padding:VALID [| 3; 3 |] [| 2; 2 |] nn in
  concat ~axis:1 [| strided_3x3; strided_dbl; strided_pool |]
(* [mix_typ4 mid_ch nn] appends an Inception block built from factorised
   7x7 convolutions (pairs of kernels with spatial dims [1; 7] and [7; 1]).
   The hard-coded kernels assume [nn] has 768 channels; [mid_ch] is the
   width of the inner layers. All convolutions use stride 1 and conv2d_bn's
   default padding. Branches:
   - 1x1, 192 filters;
   - 1x1 (mid_ch) -> 1x7 (mid_ch) -> 7x1 (192);
   - 1x1 (mid_ch) -> 7x1 -> 1x7 -> 7x1 (mid_ch each) -> 1x7 (192);
   - 3x3 average pool (stride 1, SAME_UPPER) -> 1x1, 192 filters.
   Concatenated on axis 1: 4 * 192 = 768 output channels. *)
let mix_typ4 mid_ch nn =
  let conv_s1 kernel x = conv2d_bn kernel [| 1; 1 |] x in
  (* Apply the kernels to [x] in list order. *)
  let rec chain x = function
    | [] -> x
    | kernel :: rest -> chain (conv_s1 kernel x) rest
  in
  let one_by_one = conv_s1 [| 192; 768; 1; 1 |] nn in
  let factorised_7x7 =
    chain nn
      [ [| mid_ch; 768; 1; 1 |]
      ; [| mid_ch; mid_ch; 1; 7 |]
      ; [| 192; mid_ch; 7; 1 |]
      ]
  in
  let factorised_7x7_dbl =
    chain nn
      [ [| mid_ch; 768; 1; 1 |]
      ; [| mid_ch; mid_ch; 7; 1 |]
      ; [| mid_ch; mid_ch; 1; 7 |]
      ; [| mid_ch; mid_ch; 7; 1 |]
      ; [| 192; mid_ch; 1; 7 |]
      ]
  in
  let pooled =
    avg_pool2d [| 3; 3 |] [| 1; 1 |] ~padding:SAME_UPPER nn
    |> conv_s1 [| 192; 768; 1; 1 |]
  in
  concat ~axis:1 [| one_by_one; factorised_7x7; factorised_7x7_dbl; pooled |]
(* [mix_typ8 nn] appends a block whose three branches all end in a stride-2,
   VALID-padded layer, shrinking the spatial size of [nn]. The hard-coded
   kernels assume [nn] has 768 channels.
   - 1x1 (192) -> 3x3 (320, stride 2, VALID);
   - 1x1 (192) -> 1x7 (192) -> 7x1 (192) -> 3x3 (192, stride 2, VALID);
   - 3x3 max pool, stride 2, VALID (keeps the input channels).
   Concatenated on axis 1: 320 + 192 + 768 = 1280 channels when [nn] has
   768. *)
let mix_typ8 nn =
  let reduce_3x3 =
    let t = conv2d_bn [| 192; 768; 1; 1 |] [| 1; 1 |] nn in
    conv2d_bn ~padding:VALID [| 320; 192; 3; 3 |] [| 2; 2 |] t
  in
  let reduce_7x7x3 =
    let t = conv2d_bn [| 192; 768; 1; 1 |] [| 1; 1 |] nn in
    let t = conv2d_bn [| 192; 192; 1; 7 |] [| 1; 1 |] t in
    let t = conv2d_bn [| 192; 192; 7; 1 |] [| 1; 1 |] t in
    conv2d_bn ~padding:VALID [| 192; 192; 3; 3 |] [| 2; 2 |] t
  in
  let reduce_pool = max_pool2d ~padding:VALID [| 3; 3 |] [| 2; 2 |] nn in
  concat ~axis:1 [| reduce_3x3; reduce_7x7x3; reduce_pool |]
(* [mix_typ9 in_ch nn] appends an Inception block with expanded 3x3 paths.
   [in_ch] is the input-channel dimension of every kernel that reads [nn],
   so it has to match the channel count of [nn]. All convolutions use
   stride 1 and conv2d_bn's default padding. Branches:
   - 1x1, 320 filters;
   - 1x1 (384) -> split into parallel 1x3 and 3x1 (384 each), concatenated;
   - 1x1 (448) -> 3x3 (384) -> the same 1x3 / 3x1 split;
   - 3x3 average pool (stride 1, SAME_UPPER) -> 1x1, 192 filters.
   Concatenated on axis 1: 320 + 768 + 768 + 192 = 2048 output channels. *)
let mix_typ9 in_ch nn =
  let conv_s1 kernel x = conv2d_bn kernel [| 1; 1 |] x in
  (* Fan [x] out into a 1x3 and a 3x1 convolution and concatenate them.
     Both are bound with [let] (not built inside the array literal, whose
     evaluation order OCaml leaves unspecified) so the 1x3 node is always
     created before the 3x1 node. *)
  let split_1x3_3x1 x =
    let wide = conv_s1 [| 384; 384; 1; 3 |] x in
    let tall = conv_s1 [| 384; 384; 3; 1 |] x in
    concat ~axis:1 [| wide; tall |]
  in
  let one_by_one = conv_s1 [| 320; in_ch; 1; 1 |] nn in
  let expanded = split_1x3_3x1 (conv_s1 [| 384; in_ch; 1; 1 |] nn) in
  let expanded_dbl =
    let t = conv_s1 [| 448; in_ch; 1; 1 |] nn in
    let t = conv_s1 [| 384; 448; 3; 3 |] t in
    split_1x3_3x1 t
  in
  let pooled =
    let p = avg_pool2d ~padding:SAME_UPPER [| 3; 3 |] [| 1; 1 |] nn in
    conv_s1 [| 192; in_ch; 1; 1 |] p
  in
  concat ~axis:1 [| one_by_one; expanded; expanded_dbl; pooled |]
(* [make_network ?num_classes batch img_size] builds the InceptionV3 graph
   for an input of shape [| batch; 3; img_size; img_size |] and returns the
   result of [get_network] applied to the final [activation (Softmax 1)]
   node.
   [num_classes] (default 1000, the value previously hard-coded) sets the
   output width of the final [linear] layer, so the same topology can be
   built with a different classification head. Existing calls such as
   [make_network 1 299] keep their previous behaviour.
   The integer passed to each mix_typ1 / mix_typ9 stage must equal the
   channel count produced by the stage before it; the running totals are
   noted below (they follow from the kernel shapes in each mix_* block). *)
let make_network ?(num_classes = 1000) batch img_size =
  input [| batch; 3; img_size; img_size |]
  (* stem: 3 -> 32 -> 32 -> 64 -> pool -> 80 -> 192 -> pool *)
  |> conv2d_bn [| 32; 3; 3; 3 |] [| 2; 2 |] ~padding:VALID
  |> conv2d_bn [| 32; 32; 3; 3 |] [| 1; 1 |] ~padding:VALID
  |> conv2d_bn [| 64; 32; 3; 3 |] [| 1; 1 |]
  |> max_pool2d [| 3; 3 |] [| 2; 2 |] ~padding:VALID
  |> conv2d_bn [| 80; 64; 1; 1 |] [| 1; 1 |] ~padding:VALID
  |> conv2d_bn [| 192; 80; 3; 3 |] [| 1; 1 |] ~padding:VALID
  |> max_pool2d [| 3; 3 |] [| 2; 2 |] ~padding:VALID
  (* 192 -> 256 -> 288 -> 288 channels *)
  |> mix_typ1 192 32
  |> mix_typ1 256 64
  |> mix_typ1 288 64
  (* 288 -> 768 *)
  |> mix_typ3
  (* 768 -> 768 for each mix_typ4 *)
  |> mix_typ4 128
  |> mix_typ4 160
  |> mix_typ4 160
  |> mix_typ4 192
  (* 768 -> 1280 *)
  |> mix_typ8
  (* 1280 -> 2048 -> 2048 *)
  |> mix_typ9 1280
  |> mix_typ9 2048
  (* classification head *)
  |> global_avg_pool2d
  |> linear num_classes
  |> activation (Softmax 1)
  |> get_network
(* Entry point: build the network for a batch of one 299x299 image,
   convert it to an ONNX graph and write it to "test.onnx".
   [let () =] (rather than [let _ =]) makes the compiler check that the
   export expression has type unit, so a partially applied or otherwise
   non-unit result cannot be silently discarded. *)
let () =
  let nn = make_network 1 299 in
  let onnx_graph = Owl_symbolic_engine_onnx.of_symbolic nn in
  Owl_symbolic_engine_onnx.save onnx_graph "test.onnx"