1
1
import os
2
2
import math
3
3
import random
4
- import functools
5
4
import numpy as np
6
5
import paddle
7
6
from pathlib import Path
8
7
from PIL import Image , ImageEnhance
9
8
from paddle .io import Dataset
10
9
from src .quantization .utils import ArgumentsParser
10
+ import ast
11
+ from paddle .io import DataLoader
12
+ from paddleslim .quant import quant_post_static
11
13
12
14
13
15
# Fix the Python-level RNG so dataset shuffling below is reproducible.
random.seed(0)

# Reader-pipeline tuning: worker thread count and buffered-sample capacity
# (consumed by the legacy xmap_readers pipeline).
THREAD = 16
BUF_SIZE = 10240

# Per-channel ImageNet normalization statistics (RGB mean / std), shaped
# (3, 1, 1) so they broadcast over CHW image arrays.
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
26
26
@@ -143,78 +143,15 @@ def process_image(sample,
143
143
return [img ]
144
144
145
145
146
def _reader_creator(file_list,
                    mode,
                    shuffle=False,
                    color_jitter=False,
                    rotate=False,
                    data_dir=DATA_DIR,
                    crop_size=DATA_DIM,
                    resize_size=RESIZE_DIM,
                    batch_size=1):
    """Build a threaded reader that yields preprocessed ImageNet samples.

    `file_list` is a text file: for 'train'/'val' each line is
    '<relative_path> <label>' (the '.JPEG' extension is appended here);
    for 'test' each line is a bare relative path.  Samples are pushed
    through `process_image` on THREAD workers with a BUF_SIZE buffer.
    NOTE(review): `batch_size` is accepted but never used in this body.
    """
    def reader():
        try:
            with open(file_list) as flist:
                full_lines = [line.strip() for line in flist]
                # Shuffling is done with numpy's global RNG, not `random`.
                if shuffle:
                    np.random.shuffle(full_lines)
                lines = full_lines
                for line in lines:
                    if mode == 'train' or mode == 'val':
                        img_path, label = line.split()
                        img_path = os.path.join(data_dir, img_path) + '.JPEG'
                        yield img_path, int(label)
                    elif mode == 'test':
                        # Test mode has no label; yield a one-element list.
                        img_path = os.path.join(data_dir, line)
                        yield [img_path]
        except Exception as e:
            # Hard-exit the whole process on any reader failure; os._exit
            # skips cleanup handlers (deliberate for worker threads).
            print(f'Reader failed!\n{str(e)}')
            os._exit(1)

    # Bind the preprocessing options once; workers call mapper(sample).
    mapper = functools.partial(
        process_image,
        mode=mode,
        color_jitter=color_jitter,
        rotate=rotate,
        crop_size=crop_size,
        resize_size=resize_size)

    # NOTE(review): paddle.reader.xmap_readers is a legacy API removed in
    # recent Paddle releases -- this is the deleted side of the diff.
    return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
183
-
184
-
185
def train(data_dir=DATA_DIR):
    """Create the shuffled training reader over 'train_loc.txt' in *data_dir*."""
    listing = os.path.join(data_dir, 'train_loc.txt')
    return _reader_creator(listing,
                           'train',
                           shuffle=True,
                           color_jitter=False,
                           rotate=False,
                           data_dir=data_dir)
194
-
195
-
196
def val(data_dir=DATA_DIR):
    """Create the sequential validation reader over 'val.txt' in *data_dir*."""
    listing = os.path.join(data_dir, 'val.txt')
    return _reader_creator(listing, 'val', shuffle=False, data_dir=data_dir)
199
-
200
-
201
def test(data_dir=DATA_DIR):
    """Create the sequential (unlabeled) test reader over 'test.txt' in *data_dir*."""
    listing = os.path.join(data_dir, 'test.txt')
    return _reader_creator(listing, 'test', shuffle=False, data_dir=data_dir)
204
-
205
-
206
- class ImageNetDataset (Dataset ):
207
- def __init__ (self ,
208
- data_dir = DATA_DIR ,
209
- mode = 'train' ,
210
- crop_size = DATA_DIM ,
211
- resize_size = RESIZE_DIM ):
212
- super (ImageNetDataset , self ).__init__ ()
213
- self .data_dir = data_dir
214
- self .crop_size = crop_size
215
- self .resize_size = resize_size
146
+ class PaddleDatasetReader (Dataset ):
147
    def __init__(self, args, mode='train'):
        """Index the image files under the configured data directory.

        `args` is a string-keyed config dict (produced by ArgumentsParser)
        with 'Path', 'CropResolution', 'ImageResolution' and 'BatchSize'.
        """
        super(PaddleDatasetReader, self).__init__()
        # Config values arrive as strings; literal_eval parses them back
        # into Python literals.  This implies 'Path' must be a *quoted*
        # string in the config and the resolutions int literals --
        # TODO(review): confirm against the config producer.
        self.data_dir = ast.literal_eval(args['Path'])
        self.crop_size = ast.literal_eval(args['CropResolution'])
        self.resize_size = ast.literal_eval(args['ImageResolution'])
        self.mode = mode
        self.batch_size = int(args['BatchSize'])
        # Flat (non-recursive) listing of every entry in the directory.
        self.dataset = list(Path(self.data_dir).glob('*'))
        # Order is reproducible thanks to the module-level random.seed(0).
        random.shuffle(self.dataset)
        self.dataset_iter = iter(self.dataset)
220
157
@@ -249,7 +186,50 @@ def __getitem__(self, index):
249
186
return data
250
187
251
188
    def __len__(self):
        # Dataset size = number of files discovered in the data directory.
        return len(self.dataset)
190
+
191
+
192
class PaddleQuantizationProcess:
    """Drives PaddleSlim post-training static quantization of a model.

    Parameters
    ----------
    log : logger used for progress reporting.
    model_reader : object exposing ``model_dir``, ``model_filename`` and
        ``params_filename`` locating the float model to quantize.
    dataset : calibration dataset; must also expose ``batch_size``.
    quant_params : object exposing ``image`` (the static-graph input
        variable) and ``save_dir`` (quantized-model output directory).
    """

    def __init__(self, log, model_reader, dataset, quant_params):
        self.log = log
        self.model_reader = model_reader
        self.dataset = dataset
        self.quant_params = quant_params

    def transform_fn(self):
        """Yield calibration samples one at a time as single-element lists.

        NOTE(review): not used by quantization_tflite below (which feeds
        data through a DataLoader instead); kept for API compatibility.
        """
        for data in self.dataset:
            yield [data.astype(np.float32)]

    def quantization_tflite(self, batch_nums=10, algo='avg',
                            round_type='round', hist_percent=0.9999,
                            is_full_quantize=False, bias_correction=False,
                            onnx_format=False):
        """Run post-training static quantization on CPU.

        The calibration settings that were previously hard-coded are now
        keyword parameters with identical defaults, so existing callers
        are unaffected while new callers may tune them.

        :param batch_nums: number of calibration batches to run.
        :param algo: activation-quantization algorithm ('avg', 'KL',
            'hist', ...) as accepted by quant_post_static.
        :param round_type: weight rounding method.
        :param hist_percent: percentile used when ``algo='hist'``.
        :param is_full_quantize: quantize all supported ops if True.
        :param bias_correction: apply bias-correction to weights if True.
        :param onnx_format: export in ONNX quantization format if True.
        """
        paddle.enable_static()
        place = paddle.CPUPlace()
        exe = paddle.static.Executor(place)

        # Calibration feed: batches of images bound to the graph's input
        # variable; order is kept deterministic (shuffle=False).
        data_loader = DataLoader(
            self.dataset,
            places=place,
            feed_list=[self.quant_params.image],
            drop_last=False,
            return_list=False,
            batch_size=self.dataset.batch_size,
            shuffle=False)

        quant_post_static(
            executor=exe,
            model_dir=self.model_reader.model_dir,
            quantize_model_path=self.quant_params.save_dir,
            data_loader=data_loader,
            model_filename=self.model_reader.model_filename,
            params_filename=self.model_reader.params_filename,
            batch_size=self.dataset.batch_size,
            batch_nums=batch_nums,
            algo=algo,
            round_type=round_type,
            hist_percent=hist_percent,
            is_full_quantize=is_full_quantize,
            bias_correction=bias_correction,
            onnx_format=onnx_format)
232
+
253
233
254
234
255
235
class PaddleModelReader (ArgumentsParser ):
@@ -259,16 +239,33 @@ def __init__(self, log):
259
239
    def _get_arguments(self):
        """Pull model-location settings out of the parsed argument dict.

        Expects 'PathPrefix', 'ModelDir', 'ModelFileName' and
        'ParamsFileName' keys in self.args (populated by the
        ArgumentsParser base class -- schema assumed, verify there).
        """
        self._log.info('Parsing model arguments.')
        self.path_prefix = self.args['PathPrefix']
        self.model_dir = self.args['ModelDir']
        self.model_filename = self.args['ModelFileName']
        self.params_filename = self.args['ParamsFileName']
263
245
264
246
def dict_for_iter_log (self ):
265
247
return {
266
248
'Model path prefix' : self .path_prefix ,
267
249
}
268
250
269
    def _read_model(self):
        """Load the static inference model named by ``self.path_prefix``.

        Runs on CPU; stores the loaded program and its feed/fetch
        endpoints on the instance for later execution.  (This is the
        deleted side of the diff -- replaced by plain attribute parsing.)
        """
        paddle.enable_static()
        place = paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        self.inference_program, self.feed_target_names, self.fetch_targets = paddle.static.load_inference_model(
            self.path_prefix, exe)
251
+
252
class PaddleQuantParamReader(ArgumentsParser):
    """Parses quantization parameters (input spec, output dir) from args."""

    def __init__(self, log):
        super().__init__(log)

    def dict_for_iter_log(self):
        # Raw (string) config values, intended for logging only.
        return {
            'InputShape': self.input_shape,
            'InputName': self.input_name,
            'SaveDir': self.save_dir,
        }

    def _get_arguments(self):
        # 'InputShape' arrives as a string literal (e.g. '[3, 224, 224]');
        # literal_eval turns it into a Python list -- TODO(review): confirm
        # the producer's format.
        self.image_shape = ast.literal_eval(self.args['InputShape'])
        # NOTE(review): creates a static-graph placeholder with a leading
        # (variable) batch dimension; presumably paddle.enable_static()
        # has already been called by the time this runs -- verify order.
        self.image = paddle.static.data(name=self.args['InputName'], shape=[None] + self.image_shape, dtype='float32')
        # Keep the unparsed strings as well, for dict_for_iter_log above.
        self.input_shape = self.args['InputShape']
        self.input_name = self.args['InputName']
        self.save_dir = self.args['SaveDir']

    def _convert_to_list_of_tf_objects(self, keys, dictionary):
        # Despite the "tf" in the name, this is a plain ordered lookup:
        # the dictionary values for `keys`, in the given key order.
        return [dictionary[key] for key in keys]
0 commit comments