Commit e9af8761 authored by Yury

lsq for weights

parent 7f2a535b
Branch: gste
@@ -218,7 +218,11 @@ class QuantizationPolicy(ScheduledTrainingPolicy):
         super(QuantizationPolicy, self).__init__()
         self.quantizer = quantizer
         self.quantizer.prepare_model()
-        self.quantizer.quantize_params()
+        # self.quantizer.quantize_params()
+
+    def on_epoch_begin(self, model, zeros_mask_dict, meta):
+        if meta['current_epoch'] == 0:
+            self.quantizer.quantize_params()
 
     def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         # After parameters update, quantize the parameters again
...
@@ -40,7 +40,7 @@ class LSQLinearQuantization(nn.Module):
         super(LSQLinearQuantization, self).__init__()
         self.size = size
         self.num_bits = num_bits
-        self.learned_scale = nn.Parameter(torch.tensor([(2**num_bits - 1)/3.]))  # TODO: change to better initialization
+        self.learned_scale_activation = nn.Parameter(torch.tensor([(2 ** num_bits - 1) / 3.]))
         self.dequantize = dequantize
         self.inplace = inplace
         self.half_range = half_range
@@ -62,20 +62,21 @@ class LSQLinearQuantization(nn.Module):
                                   half_range=self.half_range)
             _, clipped_max = clipper(input)
-            self.scale_init.data = (2**self.num_bits - 1) / (0.5*clipped_max + 0.5*current_max)
+            rho = 0.5
+            self.scale_init.data = (2**self.num_bits - 1) / (rho*clipped_max + (1 - rho)*current_max)
             self.initialized = True
 
         # Assume relu with zero point = 0
         # Quantize
-        input_q = self.learned_scale * input
+        input_q = self.learned_scale_activation * input
         # clamp and round
         input_q = torch.clamp(input_q, 0, 2**self.num_bits - 1)
         input_q = RoundSTE.apply(input_q)
         # dequantize
-        input_q = input_q / self.learned_scale
+        input_q = input_q / self.learned_scale_activation
 
         delta = input_q.detach() - input.detach()
         self.delta_mse.data = torch.norm(delta) / delta.numel()
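Note on the forward pass above: RoundSTE itself is not part of this diff. Below is a minimal sketch, assuming it is a standard straight-through rounding op, together with a toy run of the same scale → clamp → round → rescale sequence; the scale value, tensor, and bit-width are illustrative only.

import torch

class RoundSTE(torch.autograd.Function):
    """Round to the nearest integer; pass gradients straight through (STE)."""
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Treat d(round(x))/dx as 1 so the learned scale still receives gradients.
        return grad_output

num_bits = 4
scale = torch.tensor([3.0], requires_grad=True)   # stands in for learned_scale_activation
x = torch.rand(8) * 2.0                           # post-ReLU activations, zero point = 0
x_q = RoundSTE.apply(torch.clamp(scale * x, 0, 2 ** num_bits - 1)) / scale
x_q.sum().backward()
print(scale.grad)                                 # non-zero: the scale is trainable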
@@ -87,15 +88,90 @@ class LSQLinearQuantization(nn.Module):
         return '{0}(num_bits={1}, {2})'.format(self.__class__.__name__, self.num_bits, inplace_str)
 
 
+class LSQParamsQuantization:
+    def __init__(self, per_channel=False):
+        self.per_channel = per_channel
+
+    def __call__(self, param_fp, param_meta):
+        if self.per_channel:
+            out = self.lsq_quantize_param_per_channel(param_fp, param_meta)
+        else:
+            out = self.lsq_quantize_param(param_fp, param_meta)
+        return out
+
+    @staticmethod
+    def lsq_quantize_param_per_channel(param_fp, param_meta):
+        # return param_fp
+        scale = param_meta.module.learned_scale_weight.view(param_fp.shape[0], 1)
+        num_bits = param_meta.num_bits
+        orig_shape = param_fp.shape
+        param_fp = param_fp.view(param_fp.shape[0], -1)
+        # Quantize
+        param_q = scale * param_fp
+        # clamp and round
+        lower = -2 ** (num_bits - 1)
+        upper = 2 ** (num_bits - 1) - 1
+        param_q = torch.clamp(param_q, lower, upper)
+        param_q = RoundSTE.apply(param_q)
+        # Dequantize
+        param_q = param_q / scale
+        return param_q.view(orig_shape)
+
+    @staticmethod
+    def lsq_quantize_param(param_fp, param_meta):
+        # return param_fp
+        scale = param_meta.module.learned_scale_weight
+        num_bits = param_meta.num_bits
+        # Quantize
+        param_q = scale * param_fp
+        # clamp and round
+        lower = -2 ** (num_bits - 1)
+        upper = 2 ** (num_bits - 1) - 1
+        param_q = torch.clamp(param_q, lower, upper)
+        param_q = RoundSTE.apply(param_q)
+        # Dequantize
+        param_q = param_q / scale
+        return param_q
+
+    @staticmethod
+    def initialize_scale(float_weight, num_bits, per_channel_wts):
+        if per_channel_wts:
+            float_weight = float_weight.view(float_weight.shape[0], -1)
+            max_ = torch.abs(float_weight.max(-1)[0])
+            min_ = torch.abs(float_weight.min(-1)[0])
+            rng = torch.max(min_, max_)
+            scale = (2**(num_bits - 1) - 1) / rng
+        else:
+            rng = torch.max(torch.abs(float_weight.max()), torch.abs(float_weight.min()))
+            scale = torch.tensor([(2**(num_bits - 1) - 1) / rng])
+        return scale
+
 class LSQQuatizer(Quantizer):
     def __init__(self, model, optimizer, bits_activations=32, bits_weights=32, bits_overrides=None,
-                 quantize_bias=False, scale_decay=None, scale_lr=None):
+                 quantize_bias=False, scale_act_decay=None, scale_act_lr=None, scale_w_decay=None, scale_w_lr=None,
+                 per_channel_wts=False):
         super(LSQQuatizer, self).__init__(model, optimizer=optimizer, bits_activations=bits_activations,
                                           bits_weights=bits_weights, bits_overrides=bits_overrides,
                                           train_with_fp_copy=True, quantize_bias=quantize_bias)
-        self.scale_decay = scale_decay
-        self.scale_lr = scale_lr
+        self.scale_act_decay = scale_act_decay
+        self.scale_act_lr = scale_act_lr
+        self.scale_w_decay = scale_w_decay
+        self.scale_w_lr = scale_w_lr
+        self.per_channel_wts = per_channel_wts
         self.initialized = False
 
         def relu_replace_fn(module, name, qbits_map):
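For reference, a self-contained sketch of what the per-channel path in LSQParamsQuantization and initialize_scale computes at 4 bits. torch.round stands in for RoundSTE so the snippet runs on its own; the tensor shape and bit-width are arbitrary.

import torch

num_bits = 4
w = torch.randn(16, 3, 3, 3)                      # e.g. a conv weight [out_ch, in_ch, kh, kw]

# initialize_scale(..., per_channel_wts=True): one scale per output channel
w_flat = w.view(w.shape[0], -1)
rng = torch.max(w_flat.max(-1)[0].abs(), w_flat.min(-1)[0].abs())
scale = (2 ** (num_bits - 1) - 1) / rng           # shape [16]

# lsq_quantize_param_per_channel: scale -> clamp -> round -> rescale
lower, upper = -2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1
w_q = torch.round(torch.clamp(scale.view(-1, 1) * w_flat, lower, upper)) / scale.view(-1, 1)
w_q = w_q.view_as(w)

print((w - w_q).abs().max())                      # worst-case per-element quantization error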
@@ -107,13 +183,17 @@ class LSQQuatizer(Quantizer):
         self.replacement_factory[nn.ReLU] = relu_replace_fn
+        self.param_quantization_fn = LSQParamsQuantization(per_channel=per_channel_wts)
 
     def get_loger_stats(self, model, optimizer):
         stats_dict = OrderedDict()
-        stats_dict['global/LR'] = optimizer.param_groups[1]['lr']
-        stats_dict['global/weight_decay'] = optimizer.param_groups[1]['weight_decay']
+        stats_dict['global/scale_act_lr'] = optimizer.param_groups[1]['lr']
+        stats_dict['global/scale_act_decay'] = optimizer.param_groups[1]['weight_decay']
+        stats_dict['global/scale_w_lr'] = optimizer.param_groups[2]['lr']
+        stats_dict['global/scale_w_decay'] = optimizer.param_groups[2]['weight_decay']
         scale_params = [(n, p) for n, p in model.named_parameters() if 'learned_scale' in n]
         for name, param in scale_params:
-            stats_dict[name.replace('module.', '') + '/scale'] = param.item()
+            stats_dict[name.replace('module.', '') + '/scale'] = param.item() if param.numel() == 1 else param.mean()
         stats1 = ('Scale/', stats_dict)
 
         stats_dict = OrderedDict()
@@ -140,25 +220,41 @@ class LSQQuatizer(Quantizer):
         return [stats1, stats3, stats4]
 
     def on_minibatch_end(self, epoch, train_step, steps_per_epoch, optimizer):
+        self.quantize_params()
         if not self.initialized:
-            tract_scale = [(k, self.model.state_dict()[k]) for k in self.model.state_dict() if
-                           'scale_init' in k]
-            scale_params = [(n, p) for n, p in self.model.named_parameters() if 'learned_scale' in n]
-            for n, p in scale_params:
-                l_name = n.replace('.learned_scale', '')
-                scale = [t for n, t in tract_scale if l_name in n][0]
-                learned_scale_param = [p for n, p in scale_params if l_name in n][0]
+            scale_init_act = [(k, self.model.state_dict()[k]) for k in self.model.state_dict() if
+                              'scale_init' in k]
+            scale_params_act = [(n, p) for n, p in self.model.named_parameters() if 'learned_scale_activation' in n]
+            for n, p in scale_params_act:
+                l_name = n.replace('.learned_scale_activation', '')
+                scale = [t for n, t in scale_init_act if l_name in n][0]
+                learned_scale_param = [p for n, p in scale_params_act if l_name in n][0]
                 learned_scale_param.data.copy_(scale)
             self.initialized = True
 
     def _get_updated_optimizer_params_groups(self):
         base_group = {'params': [param for name, param in self.model.named_parameters() if 'learned_scale' not in name]}
-        scale_group = {'params': [param for name, param in self.model.named_parameters() if 'learned_scale' in name]}
+        scale_act_group = {'params': [param for name, param in self.model.named_parameters() if 'learned_scale_activation' in name]}
+        scale_w_group = {'params': [param for name, param in self.model.named_parameters() if 'learned_scale_weight' in name]}
 
-        if self.scale_lr is not None:
-            scale_group['lr'] = self.scale_lr
-        if self.scale_decay is not None:
-            scale_group['weight_decay'] = self.scale_decay
+        if self.scale_act_lr is not None:
+            scale_act_group['lr'] = self.scale_act_lr
+        if self.scale_act_decay is not None:
+            scale_act_group['weight_decay'] = self.scale_act_decay
+        if self.scale_w_lr is not None:
+            scale_w_group['lr'] = self.scale_w_lr
+        if self.scale_w_decay is not None:
+            scale_w_group['weight_decay'] = self.scale_w_decay
 
-        return [base_group, scale_group]
+        return [base_group, scale_act_group, scale_w_group]
+
+    def _prepare_model_impl(self):
+        super(LSQQuatizer, self)._prepare_model_impl()
+
+        for ptq in self.params_to_quantize:
+            m = ptq.module
+            m.learned_scale_weight = nn.Parameter(LSQParamsQuantization.initialize_scale(
+                ptq.module.float_weight, ptq.num_bits, self.per_channel_wts))
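The three parameter groups returned above map onto a standard PyTorch optimizer as sketched below. The toy module and hyper-parameter values are placeholders; in the actual flow the base Quantizer grafts these groups onto the existing optimizer inside prepare_model().

import torch
import torch.nn as nn
import torch.optim as optim

class ToyQuantLayer(nn.Module):
    """Placeholder module exposing the two kinds of learned scales."""
    def __init__(self):
        super(ToyQuantLayer, self).__init__()
        self.linear = nn.Linear(8, 8)
        self.learned_scale_activation = nn.Parameter(torch.tensor([5.0]))
        self.learned_scale_weight = nn.Parameter(torch.full((8,), 7.0))

model = ToyQuantLayer()
base      = [p for n, p in model.named_parameters() if 'learned_scale' not in n]
scale_act = [p for n, p in model.named_parameters() if 'learned_scale_activation' in n]
scale_w   = [p for n, p in model.named_parameters() if 'learned_scale_weight' in n]

optimizer = optim.SGD(
    [{'params': base},
     {'params': scale_act, 'lr': 0.1},    # scale_act_lr
     {'params': scale_w, 'lr': 1e-4}],    # scale_w_lr
    lr=0.01, momentum=0.9, weight_decay=1e-4)

# param_groups[1] and param_groups[2] are the entries get_loger_stats() reads back.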
@@ -33,7 +33,7 @@ def has_bias(module):
     return hasattr(module, 'bias') and module.bias is not None
 
 
-def hack_float_backup_parameter(module, name, num_bits, sat_mode):
+def hack_float_backup_parameter(module, name, num_bits, sat_mode=None):
     try:
         data = dict(module.named_parameters())[name].data
     except KeyError:
@@ -52,7 +52,7 @@ def hack_float_backup_parameter(module, name, num_bits, sat_mode):
     if not first:
         module.repr_mod += ' ; '
     module.repr_mod += '{0} --> {1} bits'.format(name, num_bits)
-    if 'weight' in name and num_bits < 8:
+    if 'weight' in name and num_bits < 8 and sat_mode is not None:
         sat_mode_str = str(sat_mode).split('.')[1] if sat_mode else 'No'
         module.repr_mod += ', wts_sat_mode --> {0}'.format(sat_mode_str)
@@ -165,6 +165,12 @@ class Quantizer(object):
     def prepare_model(self):
         self._prepare_model_impl()
 
+        # If an optimizer was passed, assume we need to update it
+        if self.optimizer:
+            optimizer_type = type(self.optimizer)
+            new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults)
+            self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups})
+
         msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))
 
     def _prepare_model_impl(self):
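The block moved into prepare_model() follows a generic "rebuild the param groups in place" pattern: construct a fresh optimizer of the same type from the new groups, then graft its param_groups onto the existing instance. A standalone illustration with a toy model and arbitrary groups:

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 4)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Rebuild the groups while keeping the original optimizer object and its defaults.
new_groups = [{'params': [model.weight]}, {'params': [model.bias], 'lr': 0.1}]
new_optimizer = type(optimizer)(new_groups, **optimizer.defaults)
optimizer.__setstate__({'param_groups': new_optimizer.param_groups})

print([g['lr'] for g in optimizer.param_groups])  # [0.01, 0.1]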
@@ -193,7 +199,7 @@ class Quantizer(object):
                 fp_attr_name = param_name
                 if self.train_with_fp_copy:
                     # ugly way to pass wts_sat_mode
-                    hack_float_backup_parameter(module, param_name, n_bits, self.wts_sat_mode)
+                    hack_float_backup_parameter(module, param_name, n_bits)
                     fp_attr_name = FP_BKP_PREFIX + param_name
                 self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, n_bits))
@@ -201,11 +207,6 @@ class Quantizer(object):
                 msglogger.info(
                     "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, n_bits))
 
-        # If an optimizer was passed, assume we need to update it
-        if self.optimizer:
-            optimizer_type = type(self.optimizer)
-            new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults)
-            self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups})
-
     def _pre_process_container(self, container, prefix=''):
         # Iterate through model, insert quantization functions as appropriate
...
@@ -2,27 +2,23 @@ quantizers:
   lsq_quantizer:
     class: LSQQuatizer
     bits_activations: 4
-    bits_weights: null
+    bits_weights: 4
     # num_bits_inputs: 8
     mode: 'ASYMMETRIC_UNSIGNED'  # Can try "SYMMETRIC" as well
     per_channel_wts: True
     quantize_inputs: False
     # scale_decay: 0.0001
-    scale_lr: 0.1
-    zero_point_lr: 0.001
+    scale_act_lr: 0.1
+    scale_w_lr: 0.0001
     bits_overrides:
       conv1:
-        acts: null
-        # wts: 8
+        # acts: 8
+        wts: 8
       relu:
         acts: 8
-      bn1:
-        acts: null
-      .*\.bn1:
-        acts: null
       fc:
-        acts: 8
-        # wts: 8
+        # acts: 8
+        wts: 8
 
 policies:
   - quantizer:
...
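For concreteness, the integer grids implied by the 4-bit settings in the config above, matching the clamp bounds used in the quantization code; this is a quick sanity check, not part of the config.

num_bits = 4
weight_levels = list(range(-2 ** (num_bits - 1), 2 ** (num_bits - 1)))  # -8 .. 7, symmetric weights
act_levels = list(range(0, 2 ** num_bits))                              # 0 .. 15, unsigned activations
print(weight_levels)
print(act_levels)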