From e9af8761937acdc00c39812c20eb2d94d8ad6896 Mon Sep 17 00:00:00 2001 From: Yury <yury.nahshan@intel.com> Date: Thu, 4 Apr 2019 15:02:34 +0300 Subject: [PATCH] lsq for weights --- distiller/policy.py | 6 +- distiller/quantization/learned_linear.py | 140 +++++++++++++++--- distiller/quantization/quantizer.py | 17 ++- .../quant_aware_train/lsq_4bit.yaml | 18 +-- 4 files changed, 139 insertions(+), 42 deletions(-) diff --git a/distiller/policy.py b/distiller/policy.py index f42219f..5dbda6c 100755 --- a/distiller/policy.py +++ b/distiller/policy.py @@ -218,7 +218,11 @@ class QuantizationPolicy(ScheduledTrainingPolicy): super(QuantizationPolicy, self).__init__() self.quantizer = quantizer self.quantizer.prepare_model() - self.quantizer.quantize_params() + # self.quantizer.quantize_params() + + def on_epoch_begin(self, model, zeros_mask_dict, meta): + if meta['current_epoch'] == 0: + self.quantizer.quantize_params() def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer): # After parameters update, quantize the parameters again diff --git a/distiller/quantization/learned_linear.py b/distiller/quantization/learned_linear.py index a9c58c4..bee164d 100644 --- a/distiller/quantization/learned_linear.py +++ b/distiller/quantization/learned_linear.py @@ -40,7 +40,7 @@ class LSQLinearQuantization(nn.Module): super(LSQLinearQuantization, self).__init__() self.size = size self.num_bits = num_bits - self.learned_scale = nn.Parameter(torch.tensor([(2**num_bits - 1)/3.])) # TODO: change to better initialization + self.learned_scale_activation = nn.Parameter(torch.tensor([(2 ** num_bits - 1) / 3.])) self.dequantize = dequantize self.inplace = inplace self.half_range = half_range @@ -62,20 +62,21 @@ class LSQLinearQuantization(nn.Module): half_range=self.half_range) _, clipped_max = clipper(input) - self.scale_init.data = (2**self.num_bits - 1) / (0.5*clipped_max + 0.5*current_max) + rho = 0.5 + self.scale_init.data = (2**self.num_bits - 1) / (rho*clipped_max + (1 - rho)*current_max) self.initialized = True # Assume relu with zero point = 0 # Quantize - input_q = self.learned_scale * input + input_q = self.learned_scale_activation * input # clamp and round input_q = torch.clamp(input_q, 0, 2**self.num_bits - 1) input_q = RoundSTE.apply(input_q) # dequantize - input_q = input_q / self.learned_scale + input_q = input_q / self.learned_scale_activation delta = input_q.detach() - input.detach() self.delta_mse.data = torch.norm(delta) / delta.numel() @@ -87,15 +88,90 @@ class LSQLinearQuantization(nn.Module): return '{0}(num_bits={1}, {2})'.format(self.__class__.__name__, self.num_bits, inplace_str) +class LSQParamsQuantization: + def __init__(self, per_channel=False): + self.per_channel = per_channel + + def __call__(self, param_fp, param_meta): + if self.per_channel: + out = self.lsq_quantize_param_per_channel(param_fp, param_meta) + else: + out = self.lsq_quantize_param(param_fp, param_meta) + + return out + + @staticmethod + def lsq_quantize_param_per_channel(param_fp, param_meta): + # return param_fp + scale = param_meta.module.learned_scale_weight.view(param_fp.shape[0], 1) + num_bits = param_meta.num_bits + + orig_shape = param_fp.shape + param_fp = param_fp.view(param_fp.shape[0], -1) + + # Quantize + param_q = scale * param_fp + + # clamp and round + lower = -2 ** (num_bits - 1) + upper = 2 ** (num_bits - 1) - 1 + param_q = torch.clamp(param_q, lower, upper) + param_q = RoundSTE.apply(param_q) + + # Dequantize + param_q = param_q / scale + + return 
param_q.view(orig_shape) + + @staticmethod + def lsq_quantize_param(param_fp, param_meta): + # return param_fp + scale = param_meta.module.learned_scale_weight + num_bits = param_meta.num_bits + + # Quantize + param_q = scale * param_fp + + # clamp and round + lower = -2 ** (num_bits - 1) + upper = 2 ** (num_bits - 1) - 1 + param_q = torch.clamp(param_q, lower, upper) + param_q = RoundSTE.apply(param_q) + + # Dequantize + param_q = param_q / scale + + return param_q + + @staticmethod + def initialize_scale(float_weight, num_bits, per_channel_wts): + if per_channel_wts: + float_weight = float_weight.view(float_weight.shape[0], -1) + max_ = torch.abs(float_weight.max(-1)[0]) + min_ = torch.abs(float_weight.min(-1)[0]) + rng = torch.max(min_, max_) + scale = (2**(num_bits - 1) - 1) / rng + else: + rng = torch.max(torch.abs(float_weight.max()), torch.abs(float_weight.min())) + scale = torch.tensor([(2**(num_bits - 1) - 1) / rng]) + + return scale + + class LSQQuatizer(Quantizer): def __init__(self, model, optimizer, bits_activations=32, bits_weights=32, bits_overrides=None, - quantize_bias=False, scale_decay=None, scale_lr=None): + quantize_bias=False, scale_act_decay=None, scale_act_lr=None, scale_w_decay=None, scale_w_lr=None, + per_channel_wts=False): super(LSQQuatizer, self).__init__(model, optimizer=optimizer, bits_activations=bits_activations, bits_weights=bits_weights, bits_overrides=bits_overrides, train_with_fp_copy=True, quantize_bias=quantize_bias) - self.scale_decay = scale_decay - self.scale_lr = scale_lr + self.scale_act_decay = scale_act_decay + self.scale_act_lr = scale_act_lr + self.scale_w_decay = scale_w_decay + self.scale_w_lr = scale_w_lr + self.per_channel_wts = per_channel_wts + self.initialized = False def relu_replace_fn(module, name, qbits_map): @@ -107,13 +183,17 @@ class LSQQuatizer(Quantizer): self.replacement_factory[nn.ReLU] = relu_replace_fn + self.param_quantization_fn = LSQParamsQuantization(per_channel=per_channel_wts) + def get_loger_stats(self, model, optimizer): stats_dict = OrderedDict() - stats_dict['global/LR'] = optimizer.param_groups[1]['lr'] - stats_dict['global/weight_decay'] = optimizer.param_groups[1]['weight_decay'] + stats_dict['global/scale_act_lr'] = optimizer.param_groups[1]['lr'] + stats_dict['global/scale_act_decay'] = optimizer.param_groups[1]['weight_decay'] + stats_dict['global/scale_w_lr'] = optimizer.param_groups[2]['lr'] + stats_dict['global/scale_w_decay'] = optimizer.param_groups[2]['weight_decay'] scale_params = [(n, p) for n, p in model.named_parameters() if 'learned_scale' in n] for name, param in scale_params: - stats_dict[name.replace('module.', '') + '/scale'] = param.item() + stats_dict[name.replace('module.', '') + '/scale'] = param.item() if param.numel() == 1 else param.mean() stats1 = ('Scale/', stats_dict) stats_dict = OrderedDict() @@ -140,25 +220,41 @@ class LSQQuatizer(Quantizer): return [stats1, stats3, stats4] def on_minibatch_end(self, epoch, train_step, steps_per_epoch, optimizer): + self.quantize_params() + if not self.initialized: - tract_scale = [(k, self.model.state_dict()[k]) for k in self.model.state_dict() if + scale_init_act = [(k, self.model.state_dict()[k]) for k in self.model.state_dict() if 'scale_init' in k] - scale_params = [(n, p) for n, p in self.model.named_parameters() if 'learned_scale' in n] - for n, p in scale_params: - l_name = n.replace('.learned_scale', '') - scale = [t for n, t in tract_scale if l_name in n][0] - learned_scale_param = [p for n, p in scale_params if l_name in n][0] + 
scale_params_act = [(n, p) for n, p in self.model.named_parameters() if 'learned_scale_activation' in n] + for n, p in scale_params_act: + l_name = n.replace('.learned_scale_activation', '') + scale = [t for n, t in scale_init_act if l_name in n][0] + learned_scale_param = [p for n, p in scale_params_act if l_name in n][0] learned_scale_param.data.copy_(scale) self.initialized = True def _get_updated_optimizer_params_groups(self): base_group = {'params': [param for name, param in self.model.named_parameters() if 'learned_scale' not in name]} - scale_group = {'params': [param for name, param in self.model.named_parameters() if 'learned_scale' in name]} + scale_act_group = {'params': [param for name, param in self.model.named_parameters() if 'learned_scale_activation' in name]} + scale_w_group = {'params': [param for name, param in self.model.named_parameters() if 'learned_scale_weight' in name]} + + if self.scale_act_lr is not None: + scale_act_group['lr'] = self.scale_act_lr + if self.scale_act_decay is not None: + scale_act_group['weight_decay'] = self.scale_act_decay + + if self.scale_w_lr is not None: + scale_w_group['lr'] = self.scale_w_lr + if self.scale_w_decay is not None: + scale_w_group['weight_decay'] = self.scale_w_decay + + return [base_group, scale_act_group, scale_w_group] - if self.scale_lr is not None: - scale_group['lr'] = self.scale_lr - if self.scale_decay is not None: - scale_group['weight_decay'] = self.scale_decay + def _prepare_model_impl(self): + super(LSQQuatizer, self)._prepare_model_impl() - return [base_group, scale_group] + for ptq in self.params_to_quantize: + m = ptq.module + m.learned_scale_weight = nn.Parameter(LSQParamsQuantization.initialize_scale( + ptq.module.float_weight, ptq.num_bits, self.per_channel_wts)) diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py index c3814db..f7a2bf1 100644 --- a/distiller/quantization/quantizer.py +++ b/distiller/quantization/quantizer.py @@ -33,7 +33,7 @@ def has_bias(module): return hasattr(module, 'bias') and module.bias is not None -def hack_float_backup_parameter(module, name, num_bits, sat_mode): +def hack_float_backup_parameter(module, name, num_bits, sat_mode=None): try: data = dict(module.named_parameters())[name].data except KeyError: @@ -52,7 +52,7 @@ def hack_float_backup_parameter(module, name, num_bits, sat_mode): if not first: module.repr_mod += ' ; ' module.repr_mod += '{0} --> {1} bits'.format(name, num_bits) - if 'weight' in name and num_bits < 8: + if 'weight' in name and num_bits < 8 and sat_mode is not None: sat_mode_str = str(sat_mode).split('.')[1] if sat_mode else 'No' module.repr_mod += ', wts_sat_mode --> {0}'.format(sat_mode_str) @@ -165,6 +165,12 @@ class Quantizer(object): def prepare_model(self): self._prepare_model_impl() + # If an optimizer was passed, assume we need to update it + if self.optimizer: + optimizer_type = type(self.optimizer) + new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults) + self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups}) + msglogger.info('Quantized model:\n\n{0}\n'.format(self.model)) def _prepare_model_impl(self): @@ -193,7 +199,7 @@ class Quantizer(object): fp_attr_name = param_name if self.train_with_fp_copy: # ugly way to pass wts_sat_mode - hack_float_backup_parameter(module, param_name, n_bits, self.wts_sat_mode) + hack_float_backup_parameter(module, param_name, n_bits) fp_attr_name = FP_BKP_PREFIX + param_name 
self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, n_bits)) @@ -201,11 +207,6 @@ class Quantizer(object): msglogger.info( "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, n_bits)) - # If an optimizer was passed, assume we need to update it - if self.optimizer: - optimizer_type = type(self.optimizer) - new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults) - self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups}) def _pre_process_container(self, container, prefix=''): # Iterate through model, insert quantization functions as appropriate diff --git a/examples/quantization/quant_aware_train/lsq_4bit.yaml b/examples/quantization/quant_aware_train/lsq_4bit.yaml index 96152b4..1ea5b89 100644 --- a/examples/quantization/quant_aware_train/lsq_4bit.yaml +++ b/examples/quantization/quant_aware_train/lsq_4bit.yaml @@ -2,27 +2,23 @@ quantizers: lsq_quantizer: class: LSQQuatizer bits_activations: 4 - bits_weights: null + bits_weights: 4 # num_bits_inputs: 8 mode: 'ASYMMETRIC_UNSIGNED' # Can try "SYMMETRIC" as well per_channel_wts: True quantize_inputs: False # scale_decay: 0.0001 - scale_lr: 0.1 - zero_point_lr: 0.001 + scale_act_lr: 0.1 + scale_w_lr: 0.0001 bits_overrides: conv1: - acts: null -# wts: 8 +# acts: 8 + wts: 8 relu: acts: 8 - bn1: - acts: null - .*\.bn1: - acts: null fc: - acts: 8 -# wts: 8 +# acts: 8 + wts: 8 policies: - quantizer: -- GitLab
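Note on the weight path added above: each quantized module receives a learned_scale_weight parameter, initialized from the weight range by LSQParamsQuantization.initialize_scale, and lsq_quantize_param / lsq_quantize_param_per_channel then scale, clamp, round (with a straight-through estimator) and rescale the weights every step. The standalone sketch below illustrates that round trip; RoundSTE is re-implemented here as a stand-in for distiller's version, and the helper names init_weight_scale / lsq_weight_roundtrip are illustrative only, not distiller API.

import torch
import torch.nn as nn


class RoundSTE(torch.autograd.Function):
    """Round to the nearest integer; pass gradients straight through so the
    learned scale keeps receiving updates (stand-in for distiller's RoundSTE)."""
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


def init_weight_scale(w, num_bits, per_channel):
    # Map the (per-channel) weight range onto the signed grid [-2^(b-1), 2^(b-1)-1],
    # mirroring LSQParamsQuantization.initialize_scale in the patch.
    if per_channel:
        w2d = w.view(w.shape[0], -1)
        rng = torch.max(w2d.max(dim=-1)[0].abs(), w2d.min(dim=-1)[0].abs())
    else:
        rng = torch.max(w.max().abs(), w.min().abs())
    return (2 ** (num_bits - 1) - 1) / rng


def lsq_weight_roundtrip(w, scale, num_bits, per_channel):
    # Quantize -> clamp -> round (STE) -> dequantize, as in lsq_quantize_param(_per_channel).
    orig_shape = w.shape
    if per_channel:
        w = w.view(w.shape[0], -1)
        scale = scale.view(-1, 1)
    w_q = scale * w
    w_q = torch.clamp(w_q, -2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1)
    w_q = RoundSTE.apply(w_q)
    w_q = w_q / scale
    return w_q.view(orig_shape)


if __name__ == '__main__':
    w = torch.randn(16, 8, 3, 3)
    scale = nn.Parameter(init_weight_scale(w, num_bits=4, per_channel=True))
    w_q = lsq_weight_roundtrip(w, scale, num_bits=4, per_channel=True)
    w_q.sum().backward()
    print(scale.grad.shape)  # torch.Size([16]) -- gradients reach the learned scale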
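The optimizer rework in _get_updated_optimizer_params_groups splits parameters into three groups (regular weights, activation scales, weight scales) so that scale_act_lr / scale_act_decay and scale_w_lr / scale_w_decay apply only to the learned scales; get_loger_stats then reads them back from param_groups[1] and param_groups[2]. Below is a rough, simplified equivalent of that grouping for use with a plain SGD optimizer; unlike the patch it assumes explicit lr/decay values are always supplied instead of falling back to the optimizer defaults when they are None.

import torch

def build_param_groups(model, scale_act_lr, scale_act_decay, scale_w_lr, scale_w_decay):
    # Same split as _get_updated_optimizer_params_groups; group order matters,
    # since get_loger_stats reads param_groups[1] (activation scales) and
    # param_groups[2] (weight scales).
    base, act_scales, w_scales = [], [], []
    for name, p in model.named_parameters():
        if 'learned_scale_activation' in name:
            act_scales.append(p)
        elif 'learned_scale_weight' in name:
            w_scales.append(p)
        else:
            base.append(p)
    return [
        {'params': base},
        {'params': act_scales, 'lr': scale_act_lr, 'weight_decay': scale_act_decay},
        {'params': w_scales, 'lr': scale_w_lr, 'weight_decay': scale_w_decay},
    ]

# e.g. with the learning rates from lsq_4bit.yaml (the 0.0 decays and the base
# SGD hyper-parameters are placeholders):
# optimizer = torch.optim.SGD(
#     build_param_groups(model, scale_act_lr=0.1, scale_act_decay=0.0,
#                        scale_w_lr=0.0001, scale_w_decay=0.0),
#     lr=0.01, momentum=0.9, weight_decay=1e-4)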