diff --git a/theanompi/lib/exchanger.py b/theanompi/lib/exchanger.py
index a544478..9420e08 100644
--- a/theanompi/lib/exchanger.py
+++ b/theanompi/lib/exchanger.py
@@ -44,7 +44,7 @@ def remove_BN_params(params):
 
 class BSP_Exchanger(object):
     '''
-    model parameter exchanger during BSP weight exchanging
+    model parameter/gradient exchanger based on BSP
     '''
 
     def __init__(self, comm, gpucomm, exch_strategy, sync_type, ctx, model):
@@ -54,54 +54,38 @@ def __init__(self, comm, gpucomm, exch_strategy, sync_type, ctx, model):
 
         self.size = comm.size
 
+        self.sync_type = sync_type
+
         self.exch_strategy = exch_strategy
 
-        self.sync_type = sync_type
-
-        self.param_list = remove_BN_params(model.params)
-
-        if self.sync_type == 'cdd':
+        self.vels = model.vels
+        self.vels2 = model.vels2
 
         if self.sync_type == 'cdd' and self.exch_strategy == 'ar':
 
             from theanompi.lib.exchanger_strategy import Exch_allreduce
-            self.exch = Exch_allreduce(self.comm, avg=False)
+            self.exch = Exch_allreduce(self.comm)
             self.exch.prepare(self.vels, self.vels2)
 
         elif self.sync_type == 'cdd' and self.exch_strategy == 'nccl32':
 
             from theanompi.lib.exchanger_strategy import Exch_nccl32
-            self.exch = Exch_nccl32(intercomm=self.comm, intracomm=self.gpucomm, avg=False)
+            self.exch = Exch_nccl32(intercomm=self.comm, intracomm=self.gpucomm)
             self.exch.prepare(self.ctx, self.vels, self.vels2)
 
         elif self.sync_type == 'cdd' and self.exch_strategy == 'nccl16':
 
-            from theanompi.lib.exchanger_strategy import Exch_nccl16
-            self.exch = Exch_nccl16(intercomm=self.comm, intracomm=self.gpucomm, avg=False)
-            self.exch.prepare(self.ctx, self.vels, self.vels2)
-
-        elif self.sync_type == 'avg' and self.exch_strategy == 'ar':
-
-            from theanompi.lib.exchanger_strategy import Exch_allreduce
-            self.exch = Exch_allreduce(self.comm)
-            self.exch.prepare(self.param_list)
-
-        elif self.sync_type == 'avg' and self.exch_strategy == 'nccl32':
-
-            from theanompi.lib.exchanger_strategy import Exch_nccl32
-            self.exch = Exch_nccl32(intercomm=self.comm, intracomm=self.gpucomm)
-            self.exch.prepare(self.ctx, self.param_list)
-
-        elif self.sync_type == 'avg' and self.exch_strategy == 'nccl16':
-
             from theanompi.lib.exchanger_strategy import Exch_nccl16
             self.exch = Exch_nccl16(intercomm=self.comm, intracomm=self.gpucomm)
-            self.exch.prepare(self.ctx, self.param_list)
+            self.exch.prepare(self.ctx, self.vels, self.vels2)
 
         elif self.sync_type == 'swap' and self.exch_strategy == 'nccl32':
 
+            self.param_list = remove_BN_params(model.params)
+
             from theanompi.lib.exchanger_strategy import Exch_swap
             self.exch = Exch_swap(intercomm=self.comm)
             self.exch.prepare(self.ctx, self.param_list)
diff --git a/theanompi/lib/exchanger_strategy.py b/theanompi/lib/exchanger_strategy.py
index 8ec04bd..b8ef5a5 100644
--- a/theanompi/lib/exchanger_strategy.py
+++ b/theanompi/lib/exchanger_strategy.py
@@ -29,7 +29,7 @@ class Exch_allreduce(Exch_strategy):
     '''
     paramter transfer passing host memory
     '''
-    def __init__(self, comm, avg=True):
+    def __init__(self, comm, avg=False):
         Exch_strategy.__init__(self)
 
         self.comm = comm
@@ -73,7 +73,12 @@ def exchange(self):
 
 class Exch_nccl32(Exch_strategy):
-    def __init__(self, intercomm, intracomm, avg=True):
+
+    '''
+    Single Node reduction
+    '''
+
+    def __init__(self, intercomm, intracomm, avg=False):
         Exch_strategy.__init__(self)
 
         self.intercomm = intercomm
@@ -122,7 +127,12 @@ def exchange(self):
             self.intracomm.all_reduce(source, '+', dest)
 
 class Exch_nccl16(Exch_strategy):
-    def __init__(self, intercomm, intracomm, avg=True):
+
+    '''
+    Single Node reduction (half precision)
+    '''
+
+    def __init__(self, intercomm, intracomm, avg=False):
         Exch_strategy.__init__(self)
 
         self.intercomm = intercomm
diff --git a/theanompi/lib/opt.py b/theanompi/lib/opt.py
index 21d8776..a563b51 100644
--- a/theanompi/lib/opt.py
+++ b/theanompi/lib/opt.py
@@ -1,65 +1,29 @@
-def pre_model_iter_fn(model, sync_type, f_train=True, f_val=True):
+def pre_model_iter_fn(model, k=1, f_train=True, f_val=True):
 
     # to make sure model compiles necessary functions (get_vels() and descent() for cdd, or train() for avg) and allocate necessary extra param memory (vels,vels2 for cdd, or nothing for avg)
 
     # allocate supporting params for this worker type
 
     if f_train:
-
-        if sync_type == 'cdd':
-
-            import theano
-
-            model.vels = [theano.shared(param_i.get_value() * 0.)
-                          for param_i in model.params]
-
-            model.vels2 = [theano.shared(param_i.get_value() * 0.)
-                           for param_i in model.params]
-
-            updates_v, updates_dv = prepare_update_dict(model, sync_type='cdd')
-            updates_v=fix_update_bcasts(dict(updates_v))
-            updates_dv=fix_update_bcasts(dict(updates_dv))
-
-            get_vel_args = {"inputs":[model.subb_ind], "outputs":[model.cost,model.error], "updates":updates_v, \
-                            "givens":[(model.x, model.shared_x_slice),
-                                      (model.y, model.shared_y_slice),
-                                      (model.lr, model.shared_lr)]}
-
-            descent_vel_args = {"inputs":[], "outputs":[], "updates":updates_dv}
-
-            model.compile_train(get_vel_args, descent_vel_args) # needs compile model before para_load_init() # 2 (local to worker type)
-
-            model.get_vel, model.descent_vel = model.compiled_train_fn_list
-
+        updates_v, updates_dv = prepare_update_dict(model, k=k)
-        else: # avg or other sync types
-
-            import theano
-
-            model.vels = [theano.shared(param_i.get_value() * 0.)
-                          for param_i in model.params]
-
-            model.vels2 = [theano.shared(param_i.get_value() * 0.)
-                           for param_i in model.params]
-
-            updates_w, = prepare_update_dict(model, sync_type='avg')
-
+        updates_v=fix_update_bcasts(dict(updates_v))
+        updates_dv=fix_update_bcasts(dict(updates_dv))
-            updates_w=fix_update_bcasts(dict(updates_w))
-
-
-            train_args = {"inputs":[model.subb_ind], "outputs": [model.cost,model.error], "updates": updates_w, \
-                          "givens": [(model.x, model.shared_x_slice),
-                                     (model.y, model.shared_y_slice),
-                                     (model.lr, model.shared_lr)]}
-
-            model.compile_train(train_args)
-
-            model.train_fn , = model.compiled_train_fn_list
+        get_vel_args = {"inputs":[model.subb_ind], "outputs":[model.cost,model.error], "updates":updates_v, \
+                        "givens":[(model.x, model.shared_x_slice),
+                                  (model.y, model.shared_y_slice),
+                                  (model.lr, model.shared_lr)]}
+
+        descent_vel_args = {"inputs":[], "outputs":[], "updates":updates_dv}
+
+        model.compile_train(get_vel_args, descent_vel_args) # needs compile model before para_load_init() # 2 (local to worker type)
+
+        model.get_vel, model.descent_vel = model.compiled_train_fn_list
 
-        model.train_iter_fn = choose_iter_fn(model, sync_type)
+        model.train_iter_fn = choose_iter_fn(model)
 
     if f_val:
@@ -74,42 +38,29 @@ def fix_update_bcasts(updates):
             updates[param] = T.patternbroadcast(update, param.broadcastable)
     return updates
 
-def choose_iter_fn(model, sync_type):
+def choose_iter_fn(model):
 
-    if sync_type == 'cdd':
-
-        def cdd_iter_fn(subb_ind):
-            model.descent_vel()
-            cost, error = model.get_vel(subb_ind)
-            return cost, error
-
-        return cdd_iter_fn
-
-    elif sync_type == 'avg':
-
-        return model.train_fn
-
-def prepare_update_dict(model, sync_type, clip=False):
+    # TODO maybe not be correct to perform step3 step1 -> step2
+
+    def cdd_iter_fn(subb_ind):
+        model.descent_vel()
+        cost, error = model.get_vel(subb_ind)
+        return cost, error
+    return cdd_iter_fn
+
+def prepare_update_dict(model, k=1):
 
     if model.use_momentum:
-        updates_w, updates_v, updates_dv = BSP_MSGD(model, model.use_nesterov_momentum,sync_type, clip)
+        updates_v, updates_dv = BSP_MSGD(model, model.use_nesterov_momentum, k=k)
     else:
-        updates_w, updates_v, updates_dv = BSP_SGD(model, sync_type, clip)
+        updates_v, updates_dv = BSP_SGD(model, k=k)
 
-    if sync_type == 'cdd':
-        update_dict = [updates_v, updates_dv]
-
-    elif sync_type == 'avg':
-
-        update_dict = [updates_w]
-
-
-    return update_dict
+    return updates_v, updates_dv
@@ -121,20 +72,21 @@ def _clip_paramlist(param_list, scale=10):
 
     return res
 
-def BSP_MSGD(model, use_nesterov_momentum,sync_type, clip):
+def BSP_MSGD(model, use_nesterov_momentum):
 
     params, grads, weight_types = model.params, model.grads, model.weight_types
 
-    if clip==True:
-        grads=_clip_paramlist(grads)
-
-    vels, vels2 = model.vels, model.vels2
+    import theano
+
+    model.vels = [theano.shared(param_i.get_value() * 0.)
+                  for param_i in model.params]
+
+    model.vels2 = [theano.shared(param_i.get_value() * 0.)
+                   for param_i in model.params]
 
     lr = model.lr #shared_lr #T.scalar('lr') # symbolic learning rate
     mu = model.mu # def: 0.9 # momentum
     eta = model.eta #0.0002 # weight decay
-
-    updates_w = [] # for avg
 
     updates_v = [] # for cdd
     updates_dv = [] # for cdd
@@ -159,83 +111,60 @@
             vel_i_next = mu ** 2 * vels[k] - (1 + mu) * real_lr * real_grad
         else:
             vel_i_next = mu * vels[k] - real_lr * real_grad
-
-        if sync_type == 'cdd':
-            updates_v.append((vels[k], vel_i_next))
-            updates_dv.append((param_i, param_i + vels2[k]))
-
-        elif sync_type == 'avg':
-
-            updates_w.append((vels[k], vel_i_next))
-            updates_w.append((param_i, param_i + vel_i_next))
+        updates_v.append((vels[k], vel_i_next))
+        updates_dv.append((param_i, param_i + vels2[k]))
 
         k=k+1
 
-    return updates_w, updates_v, updates_dv
+    return updates_v, updates_dv
 
-
-def BSP_SGD(model,sync_type, clip):
+def BSP_SGD(model, k=1):
 
     params, grads, weight_types = model.params, model.grads, model.weight_types
-
-    if clip==True:
-        grads=_clip_paramlist(grads)
-    vels, vels2 = model.vels, model.vels2
+    import theano
+
+    model.vels=[]
+    model.vels2=[]
 
     lr = model.lr #shared_lr #T.scalar('lr') # symbolic learning rate
-    mu = model.mu # def: 0.9 # momentum
     eta = model.eta #0.0002 # weight decay
-
-    updates_w = [] # for avg
-    updates_v = [] # for cdd
-    updates_dv = [] # for cdd
+    updates_pre_g_aggre = [] # pre gradient aggregation
+    updates_post_g_aggre = [] # post gradient aggregation
 
-    assert len(weight_types) == len(params)
-
-
-    k=0
-
-    for param_i, grad_i, weight_type in \
-        zip(params, grads, weight_types):
+    for ind, (param_i, grad_i, weight_type) in enumerate(
+        zip(params, grads, weight_types)):
-
         if weight_type == 'W':
-            if sync_type == 'cdd':
-
-                update = - lr * grad_i - eta * lr * param_i
-
-            elif sync_type == 'avg':
-
-                update = param_i - lr * grad_i - eta * lr * param_i
+            update = - lr * grad_i - eta * lr * param_i
 
         elif weight_type == 'b':
-            if sync_type == 'cdd':
+            update = - 2 * lr * grad_i
+
+        if param_i.name in ['gamma', 'beta']: # explicitly not exchanging BN parameters
-
-                update = - 2 * lr * grad_i
-
-            elif sync_type == 'avg':
-
-                update = param_i - 2 * lr * grad_i
-
-        if sync_type == 'cdd':
+            updates_pre_g_aggre.append((param_i, param_i + update)) # step 3: update local param with vels2
-            updates_v.append((vels[k], update))
-            updates_dv.append((param_i, param_i + vels2[k]))
+        else:
-        elif sync_type == 'avg':
+            tmp1=theano.shared(param_i.get_value() * 0.)
+            tmp2=theano.shared(param_i.get_value() * 0.)
-            # updates_w.append((vel_i, - 2 * lr * grad_i))
-            updates_w.append((param_i, update))
+            updates_pre_g_aggre.append((tmp1, update/float(k))) # step 1: process per-worker gradient into vels
+            # step 2 (during exchanging): allreduce per-worker gradients into vels2
+            updates_post_g_aggre.append((param_i, param_i + tmp2)) # step 3: update local param with vels2
+            model.vels.append(tmp1)
+            model.vels2.append(tmp2)
 
-        k=k+1
+    # in practice BSP:
+    # training (step3-> step1) - > comm (step 2)
 
-    return updates_w, updates_v, updates_dv
+    return updates_pre_g_aggre, updates_post_g_aggre
 
 def MSGD(model, use_nesterov_momentum,sync_type, clip):
diff --git a/theanompi/models/lasagne_model_zoo/resnet50.py b/theanompi/models/lasagne_model_zoo/resnet50.py
index a38c58f..759a767 100644
--- a/theanompi/models/lasagne_model_zoo/resnet50.py
+++ b/theanompi/models/lasagne_model_zoo/resnet50.py
@@ -345,7 +345,7 @@ def build_model(self):
         self.y = T.lvector('y')
         self.lr = T.scalar('lr')
 
-        net = build_model_resnet152(input_shape=(None, 3, 224, 224))
+        net = build_model_resnet50(input_shape=(None, 3, 224, 224))
 
         self.output_layer = net['prob']
@@ -414,7 +414,7 @@ def compile_iter_fns(self, sync_type):
 
         from theanompi.lib.opt import pre_model_iter_fn
 
-        pre_model_iter_fn(self, sync_type=sync_type)
+        pre_model_iter_fn(self, self.size)
 
         if self.verbose: print('Compile time: %.3f s' % (time.time()-start))
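
Taken together, the new code path works as follows: within one training iteration cdd_iter_fn first applies the update aggregated during the previous exchange (step 3, descent_vel), then computes the current per-worker update into vels (step 1, get_vel; in BSP_SGD it is pre-divided by the worker count k), and between iterations the BSP_Exchanger allreduces vels into vels2 (step 2). Below is a minimal plain-Python/NumPy sketch of that schedule; it is illustrative only and not part of the patch — the two-worker setup, local_gradient() and the in-process "allreduce" loop are hypothetical stand-ins for the Theano shared variables and the MPI/NCCL exchange used by theanompi.

import numpy as np

k = 2            # number of workers (passed as model.size in the patch)
lr = 0.1
params = [np.array(1.0) for _ in range(k)]   # each worker's copy of one parameter
vels   = [np.array(0.0) for _ in range(k)]   # per-worker scaled update (step 1)
vels2  = [np.array(0.0) for _ in range(k)]   # aggregated update (step 2)

def local_gradient(worker, param):
    # stand-in for the gradient computed on this worker's mini-batch
    return (worker + 1) * param

for iteration in range(3):
    for w in range(k):
        # step 3 (descent_vel): apply the update aggregated last iteration
        params[w] = params[w] + vels2[w]
        # step 1 (get_vel): compute this worker's update, pre-divided by k
        vels[w] = -lr * local_gradient(w, params[w]) / float(k)
    # step 2 (exchanger): allreduce vels into vels2 on every worker
    total = sum(vels)
    for w in range(k):
        vels2[w] = total.copy()

# because every worker adds the same aggregated update, all copies stay in sync
print(params)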