Skip to content

Commit db7ef6b

Browse files
committedApr 17, 2025
structure fix
1 parent 583f819 commit db7ef6b

File tree

3 files changed

+89
-124
lines changed

3 files changed

+89
-124
lines changed
 

‎src/quantization/paddlepaddle/parameters.py

Lines changed: 79 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import os
22
import math
33
import random
4-
import functools
54
import numpy as np
65
import paddle
76
from pathlib import Path
87
from PIL import Image, ImageEnhance
98
from paddle.io import Dataset
109
from src.quantization.utils import ArgumentsParser
10+
import ast
11+
from paddle.io import DataLoader
12+
from paddleslim.quant import quant_post_static
1113

1214

1315
random.seed(0)
@@ -19,8 +21,6 @@
1921
THREAD = 16
2022
BUF_SIZE = 10240
2123

22-
DATA_DIR = r'D:\ws\dl-benchmark\imagenet'
23-
2424
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
2525
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
2626

@@ -143,78 +143,15 @@ def process_image(sample,
143143
return [img]
144144

145145

146-
def _reader_creator(file_list,
147-
mode,
148-
shuffle=False,
149-
color_jitter=False,
150-
rotate=False,
151-
data_dir=DATA_DIR,
152-
crop_size=DATA_DIM,
153-
resize_size=RESIZE_DIM,
154-
batch_size=1):
155-
def reader():
156-
try:
157-
with open(file_list) as flist:
158-
full_lines = [line.strip() for line in flist]
159-
if shuffle:
160-
np.random.shuffle(full_lines)
161-
lines = full_lines
162-
for line in lines:
163-
if mode == 'train' or mode == 'val':
164-
img_path, label = line.split()
165-
img_path = os.path.join(data_dir, img_path) + '.JPEG'
166-
yield img_path, int(label)
167-
elif mode == 'test':
168-
img_path = os.path.join(data_dir, line)
169-
yield [img_path]
170-
except Exception as e:
171-
print(f'Reader failed!\n{str(e)}')
172-
os._exit(1)
173-
174-
mapper = functools.partial(
175-
process_image,
176-
mode=mode,
177-
color_jitter=color_jitter,
178-
rotate=rotate,
179-
crop_size=crop_size,
180-
resize_size=resize_size)
181-
182-
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
183-
184-
185-
def train(data_dir=DATA_DIR):
186-
file_list = os.path.join(data_dir, 'train_loc.txt')
187-
return _reader_creator(
188-
file_list,
189-
'train',
190-
shuffle=True,
191-
color_jitter=False,
192-
rotate=False,
193-
data_dir=data_dir)
194-
195-
196-
def val(data_dir=DATA_DIR):
197-
file_list = os.path.join(data_dir, 'val.txt')
198-
return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
199-
200-
201-
def test(data_dir=DATA_DIR):
202-
file_list = os.path.join(data_dir, 'test.txt')
203-
return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
204-
205-
206-
class ImageNetDataset(Dataset):
207-
def __init__(self,
208-
data_dir=DATA_DIR,
209-
mode='train',
210-
crop_size=DATA_DIM,
211-
resize_size=RESIZE_DIM):
212-
super(ImageNetDataset, self).__init__()
213-
self.data_dir = data_dir
214-
self.crop_size = crop_size
215-
self.resize_size = resize_size
146+
class PaddleDatasetReader(Dataset):
147+
def __init__(self, args, mode='train'):
148+
super(PaddleDatasetReader, self).__init__()
149+
self.data_dir = ast.literal_eval(args['Path'])
150+
self.crop_size = ast.literal_eval(args['CropResolution'])
151+
self.resize_size = ast.literal_eval(args['ImageResolution'])
216152
self.mode = mode
217-
self.dataset = list(Path(data_dir).glob('*'))
153+
self.batch_size = int(args['BatchSize'])
154+
self.dataset = list(Path(self.data_dir).glob('*'))
218155
random.shuffle(self.dataset)
219156
self.dataset_iter = iter(self.dataset)
220157

@@ -249,7 +186,50 @@ def __getitem__(self, index):
249186
return data
250187

251188
def __len__(self):
252-
return len(self.data)
189+
return len(self.dataset)
190+
191+
192+
class PaddleQuantizationProcess:
193+
def __init__(self, log, model_reader, dataset, quant_params):
194+
self.log = log
195+
self.model_reader = model_reader
196+
self.dataset = dataset
197+
self.quant_params = quant_params
198+
199+
def transform_fn(self):
200+
for data in self.dataset:
201+
yield [data.astype(np.float32)]
202+
203+
def quantization_tflite(self):
204+
paddle.enable_static()
205+
place = paddle.CPUPlace()
206+
exe = paddle.static.Executor(place)
207+
208+
data_loader = DataLoader(
209+
self.dataset,
210+
places=place,
211+
feed_list=[self.quant_params.image],
212+
drop_last=False,
213+
return_list=False,
214+
batch_size=self.dataset.batch_size,
215+
shuffle=False)
216+
217+
quant_post_static(
218+
executor=exe,
219+
model_dir=self.model_reader.model_dir,
220+
quantize_model_path=self.quant_params.save_dir,
221+
data_loader=data_loader,
222+
model_filename=self.model_reader.model_filename,
223+
params_filename=self.model_reader.params_filename,
224+
batch_size=self.dataset.batch_size,
225+
batch_nums=10,
226+
algo='avg',
227+
round_type='round',
228+
hist_percent=0.9999,
229+
is_full_quantize=False,
230+
bias_correction=False,
231+
onnx_format=False)
232+
253233

254234

255235
class PaddleModelReader(ArgumentsParser):
@@ -259,16 +239,33 @@ def __init__(self, log):
259239
def _get_arguments(self):
260240
self._log.info('Parsing model arguments.')
261241
self.path_prefix = self.args['PathPrefix']
262-
self._read_model()
242+
self.model_dir = self.args['ModelDir']
243+
self.model_filename = self.args['ModelFileName']
244+
self.params_filename = self.args['ParamsFileName']
263245

264246
def dict_for_iter_log(self):
265247
return {
266248
'Model path prefix': self.path_prefix,
267249
}
268250

269-
def _read_model(self):
270-
paddle.enable_static()
271-
place = paddle.CPUPlace()
272-
exe = paddle.static.Executor(place)
273-
self.inference_program, self.feed_target_names, self.fetch_targets = paddle.static.load_inference_model(
274-
self.path_prefix, exe)
251+
252+
class PaddleQuantParamReader(ArgumentsParser):
253+
def __init__(self, log):
254+
super().__init__(log)
255+
256+
def dict_for_iter_log(self):
257+
return {
258+
'InputShape': self.input_shape,
259+
'InputName': self.input_name,
260+
'SaveDir': self.save_dir,
261+
}
262+
263+
def _get_arguments(self):
264+
self.image_shape = ast.literal_eval(self.args['InputShape'])
265+
self.image = paddle.static.data(name=self.args['InputName'], shape=[None] + self.image_shape, dtype='float32')
266+
self.input_shape = self.args['InputShape']
267+
self.input_name = self.args['InputName']
268+
self.save_dir = self.args['SaveDir']
269+
270+
def _convert_to_list_of_tf_objects(self, keys, dictionary):
271+
return [dictionary[key] for key in keys]

‎src/quantization/paddlepaddle/quantization_paddlepaddle.py

Lines changed: 9 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,7 @@
44
from pathlib import Path
55
from ...utils.logger_conf import configure_logger # noqa: E402
66
from ...quantization.utils import ConfigParser # noqa: E402
7-
from paddleslim.quant import quant_post_static
8-
import paddle
9-
from paddle.io import DataLoader
10-
from parameters import PaddleModelReader, ImageNetDataset
11-
import ast
7+
from parameters import PaddleModelReader, PaddleDatasetReader, PaddleQuantizationProcess, PaddleQuantParamReader
128

139
sys.path.append(str(Path(__file__).resolve().parents[3]))
1410

@@ -31,47 +27,18 @@ def main():
3127
try:
3228
log.info(f'Parsing the configuration file {args.config}')
3329
parser = ConfigParser(args.config)
30+
3431
config = parser.parse()
3532
exit_code = 0
33+
quant_params = PaddleQuantParamReader(log)
34+
model_reader = PaddleModelReader(log)
3635
for model_quant_config in config:
3736
try:
38-
paddle.enable_static()
39-
place = paddle.CPUPlace()
40-
exe = paddle.static.Executor(place)
41-
val_dataset = (
42-
ImageNetDataset(mode='test',
43-
crop_size=ast.literal_eval(model_quant_config[1]['Dataset']['CropResolution']),
44-
resize_size=ast.literal_eval(model_quant_config[1]['Dataset']['ImageResolution']),
45-
data_dir=model_quant_config[1]['Dataset']['Path']))
46-
47-
image_shape = ast.literal_eval(model_quant_config[2]['Parameters']['InputShape'])
48-
image = paddle.static.data(
49-
name=model_quant_config[2]['Parameters']['InputName'], shape=[None] + image_shape, dtype='float32')
50-
51-
data_loader = DataLoader(
52-
val_dataset,
53-
places=place,
54-
feed_list=[image],
55-
drop_last=False,
56-
return_list=False,
57-
batch_size=32,
58-
shuffle=False)
59-
60-
quant_post_static(
61-
executor=exe,
62-
model_dir=model_quant_config[0]['Model']['ModelDir'],
63-
quantize_model_path=model_quant_config[2]['Parameters']['SaveDir'],
64-
data_loader=data_loader,
65-
model_filename=model_quant_config[0]['Model']['ModelFileName'],
66-
params_filename=model_quant_config[0]['Model']['ParamsFileName'],
67-
batch_size=int(model_quant_config[1]['DataSet']['BatchSize']),
68-
batch_nums=10,
69-
algo='avg',
70-
round_type='round',
71-
hist_percent=0.9999,
72-
is_full_quantize=False,
73-
bias_correction=False,
74-
onnx_format=False)
37+
data_reader = PaddleDatasetReader(model_quant_config[1]['Dataset'])
38+
model_reader.add_arguments(model_quant_config[0]['Model'])
39+
quant_params.add_arguments(model_quant_config[2]['QuantizationParameters'])
40+
proc = PaddleQuantizationProcess(log, model_reader, data_reader, quant_params)
41+
proc.quantization_tflite()
7542

7643
except Exception:
7744
log.error(traceback.format_exc())

‎tests/smoke_test/configs/quantization_models/resnet-50_PADDLEPADDLE.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
<Path>../test_images/classification_images</Path>
1414
<Mean>[123.675, 116.28, 103.53]</Mean>
1515
<Std>[58.395, 57.12, 57.375]</Std>
16+
<BatchSize>1</BatchSize>
1617
<ImageResolution>[224, 224]</ImageResolution>
1718
<ResizeResolution>[256, 256]</ResizeResolution>
1819
</Dataset>

0 commit comments

Comments
 (0)