From e9af8761937acdc00c39812c20eb2d94d8ad6896 Mon Sep 17 00:00:00 2001
From: Yury <yury.nahshan@intel.com>
Date: Thu, 4 Apr 2019 15:02:34 +0300
Subject: [PATCH] LSQ quantization for weights
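
Extend the LSQ quantizer so that weight scales are learned alongside the
activation scales:

* Rename LSQLinearQuantization's scale parameter to learned_scale_activation.
* Add LSQParamsQuantization, which quantizes weights through a learned
  per-tensor or per-channel scale (learned_scale_weight), registered in
  LSQQuatizer._prepare_model_impl() and initialized from the range of the
  float weights. For a weight w, scale s and b bits the quant-dequant is
      w_q = round(clamp(s * w, -2**(b-1), 2**(b-1) - 1)) / s
  with a straight-through estimator for round().
* Give the activation and weight scales separate optimizer parameter groups
  with their own lr / weight_decay (scale_act_lr/decay, scale_w_lr/decay).
* Move the optimizer rebuild from _prepare_model_impl() to prepare_model(),
  so that parameters registered by subclasses are picked up.
* Defer the initial quantize_params() call to the start of epoch 0.

The lsq_4bit.yaml example now quantizes weights as well, e.g.:

    lsq_quantizer:
      class: LSQQuatizer
      bits_activations: 4
      bits_weights: 4
      per_channel_wts: True
      scale_act_lr: 0.1
      scale_w_lr: 0.0001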

---
 distiller/policy.py                           |   6 +-
 distiller/quantization/learned_linear.py      | 140 +++++++++++++++---
 distiller/quantization/quantizer.py           |  17 ++-
 .../quant_aware_train/lsq_4bit.yaml           |  18 +--
 4 files changed, 139 insertions(+), 42 deletions(-)

diff --git a/distiller/policy.py b/distiller/policy.py
index f42219f..5dbda6c 100755
--- a/distiller/policy.py
+++ b/distiller/policy.py
@@ -218,7 +218,11 @@ class QuantizationPolicy(ScheduledTrainingPolicy):
         super(QuantizationPolicy, self).__init__()
         self.quantizer = quantizer
         self.quantizer.prepare_model()
-        self.quantizer.quantize_params()
+        # The initial quantize_params() call now happens at the start of the first epoch (on_epoch_begin).
+
+    def on_epoch_begin(self, model, zeros_mask_dict, meta):
+        if meta['current_epoch'] == 0:
+            self.quantizer.quantize_params()
 
     def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
         # After parameters update, quantize the parameters again
diff --git a/distiller/quantization/learned_linear.py b/distiller/quantization/learned_linear.py
index a9c58c4..bee164d 100644
--- a/distiller/quantization/learned_linear.py
+++ b/distiller/quantization/learned_linear.py
@@ -40,7 +40,7 @@ class LSQLinearQuantization(nn.Module):
         super(LSQLinearQuantization, self).__init__()
         self.size = size
         self.num_bits = num_bits
-        self.learned_scale = nn.Parameter(torch.tensor([(2**num_bits - 1)/3.]))  # TODO: change to better initialization
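+        # Rough initial value; overwritten with the data-driven scale_init after the first
+        # minibatch (see LSQQuatizer.on_minibatch_end).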
+        self.learned_scale_activation = nn.Parameter(torch.tensor([(2 ** num_bits - 1) / 3.]))
         self.dequantize = dequantize
         self.inplace = inplace
         self.half_range = half_range
@@ -62,20 +62,21 @@ class LSQLinearQuantization(nn.Module):
                                                half_range=self.half_range)
 
                 _, clipped_max = clipper(input)
-                self.scale_init.data = (2**self.num_bits - 1) / (0.5*clipped_max + 0.5*current_max)
+                rho = 0.5
+                self.scale_init.data = (2**self.num_bits - 1) / (rho*clipped_max + (1 - rho)*current_max)
                 self.initialized = True
 
         # Assume relu with zero point = 0
 
         # Quantize
-        input_q = self.learned_scale * input
+        input_q = self.learned_scale_activation * input
 
         # clamp and round
         input_q = torch.clamp(input_q, 0, 2**self.num_bits - 1)
         input_q = RoundSTE.apply(input_q)
 
         # dequantize
-        input_q = input_q / self.learned_scale
+        input_q = input_q / self.learned_scale_activation
 
         delta = input_q.detach() - input.detach()
         self.delta_mse.data = torch.norm(delta) / delta.numel()
@@ -87,15 +88,90 @@ class LSQLinearQuantization(nn.Module):
         return '{0}(num_bits={1}, {2})'.format(self.__class__.__name__, self.num_bits, inplace_str)
 
 
+class LSQParamsQuantization:
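+    """Quantize-dequantize weight tensors through a learned scale, per tensor or per channel."""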
+    def __init__(self, per_channel=False):
+        self.per_channel = per_channel
+
+    def __call__(self, param_fp, param_meta):
+        if self.per_channel:
+            out = self.lsq_quantize_param_per_channel(param_fp, param_meta)
+        else:
+            out = self.lsq_quantize_param(param_fp, param_meta)
+
+        return out
+
+    @staticmethod
+    def lsq_quantize_param_per_channel(param_fp, param_meta):
+        # One learned scale per output channel (first dimension of the weight tensor).
+        scale = param_meta.module.learned_scale_weight.view(param_fp.shape[0], 1)
+        num_bits = param_meta.num_bits
+
+        orig_shape = param_fp.shape
+        param_fp = param_fp.view(param_fp.shape[0], -1)
+
+        # Quantize
+        param_q = scale * param_fp
+
+        # clamp and round
+        lower = -2 ** (num_bits - 1)
+        upper = 2 ** (num_bits - 1) - 1
+        param_q = torch.clamp(param_q, lower, upper)
+        param_q = RoundSTE.apply(param_q)
+
+        # Dequantize
+        param_q = param_q / scale
+
+        return param_q.view(orig_shape)
+
+    @staticmethod
+    def lsq_quantize_param(param_fp, param_meta):
+        # A single learned scale for the whole weight tensor.
+        scale = param_meta.module.learned_scale_weight
+        num_bits = param_meta.num_bits
+
+        # Quantize
+        param_q = scale * param_fp
+
+        # clamp and round
+        lower = -2 ** (num_bits - 1)
+        upper = 2 ** (num_bits - 1) - 1
+        param_q = torch.clamp(param_q, lower, upper)
+        param_q = RoundSTE.apply(param_q)
+
+        # Dequantize
+        param_q = param_q / scale
+
+        return param_q
+
+    @staticmethod
+    def initialize_scale(float_weight, num_bits, per_channel_wts):
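+        # Choose a scale that maps the largest absolute weight to 2**(num_bits - 1) - 1,
+        # the top of the signed integer range, per output channel or per tensor.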
+        if per_channel_wts:
+            float_weight = float_weight.view(float_weight.shape[0], -1)
+            max_ = torch.abs(float_weight.max(-1)[0])
+            min_ = torch.abs(float_weight.min(-1)[0])
+            rng = torch.max(min_, max_)
+            scale = (2**(num_bits - 1) - 1) / rng
+        else:
+            rng = torch.max(torch.abs(float_weight.max()), torch.abs(float_weight.min()))
+            scale = torch.tensor([(2**(num_bits - 1) - 1) / rng])
+
+        return scale
+
+
 class LSQQuatizer(Quantizer):
     def __init__(self, model, optimizer, bits_activations=32, bits_weights=32, bits_overrides=None,
-                 quantize_bias=False, scale_decay=None, scale_lr=None):
+                 quantize_bias=False, scale_act_decay=None, scale_act_lr=None, scale_w_decay=None, scale_w_lr=None,
+                 per_channel_wts=False):
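+        # scale_act_* and scale_w_* set the lr / weight_decay of the optimizer groups
+        # holding the learned activation and weight scales, respectively.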
         super(LSQQuatizer, self).__init__(model, optimizer=optimizer, bits_activations=bits_activations,
                                             bits_weights=bits_weights, bits_overrides=bits_overrides,
                                             train_with_fp_copy=True, quantize_bias=quantize_bias)
 
-        self.scale_decay = scale_decay
-        self.scale_lr = scale_lr
+        self.scale_act_decay = scale_act_decay
+        self.scale_act_lr = scale_act_lr
+        self.scale_w_decay = scale_w_decay
+        self.scale_w_lr = scale_w_lr
+        self.per_channel_wts = per_channel_wts
+
         self.initialized = False
 
         def relu_replace_fn(module, name, qbits_map):
@@ -107,13 +183,17 @@ class LSQQuatizer(Quantizer):
 
         self.replacement_factory[nn.ReLU] = relu_replace_fn
 
+        self.param_quantization_fn = LSQParamsQuantization(per_channel=per_channel_wts)
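+        # quantize_params() applies this to the weights, i.e. quantizes them through the learned scales.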
+
     def get_loger_stats(self, model, optimizer):
         stats_dict = OrderedDict()
-        stats_dict['global/LR'] = optimizer.param_groups[1]['lr']
-        stats_dict['global/weight_decay'] = optimizer.param_groups[1]['weight_decay']
+        stats_dict['global/scale_act_lr'] = optimizer.param_groups[1]['lr']
+        stats_dict['global/scale_act_decay'] = optimizer.param_groups[1]['weight_decay']
+        stats_dict['global/scale_w_lr'] = optimizer.param_groups[2]['lr']
+        stats_dict['global/scale_w_decay'] = optimizer.param_groups[2]['weight_decay']
         scale_params = [(n, p) for n, p in model.named_parameters() if 'learned_scale' in n]
         for name, param in scale_params:
-            stats_dict[name.replace('module.', '') + '/scale'] = param.item()
+            stats_dict[name.replace('module.', '') + '/scale'] = param.item() if param.numel() == 1 else param.mean().item()
         stats1 = ('Scale/', stats_dict)
 
         stats_dict = OrderedDict()
@@ -140,25 +220,41 @@ class LSQQuatizer(Quantizer):
         return [stats1, stats3, stats4]
 
     def on_minibatch_end(self, epoch, train_step, steps_per_epoch, optimizer):
+        self.quantize_params()
+
         if not self.initialized:
-            tract_scale = [(k, self.model.state_dict()[k]) for k in self.model.state_dict() if
+            scale_init_act = [(k, self.model.state_dict()[k]) for k in self.model.state_dict() if
                          'scale_init' in k]
-            scale_params = [(n, p) for n, p in self.model.named_parameters() if 'learned_scale' in n]
-            for n, p in scale_params:
-                l_name = n.replace('.learned_scale', '')
-                scale = [t for n, t in tract_scale if l_name in n][0]
-                learned_scale_param = [p for n, p in scale_params if l_name in n][0]
+            scale_params_act = [(n, p) for n, p in self.model.named_parameters() if 'learned_scale_activation' in n]
+            for n, p in scale_params_act:
+                l_name = n.replace('.learned_scale_activation', '')
+                scale = [t for n, t in scale_init_act if l_name in n][0]
+                learned_scale_param = [p for n, p in scale_params_act if l_name in n][0]
                 learned_scale_param.data.copy_(scale)
 
             self.initialized = True
 
     def _get_updated_optimizer_params_groups(self):
         base_group = {'params': [param for name, param in self.model.named_parameters() if 'learned_scale' not in name]}
-        scale_group = {'params': [param for name, param in self.model.named_parameters() if 'learned_scale' in name]}
+        scale_act_group = {'params': [param for name, param in self.model.named_parameters() if 'learned_scale_activation' in name]}
+        scale_w_group = {'params': [param for name, param in self.model.named_parameters() if 'learned_scale_weight' in name]}
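+        # Group order (base params, activation scales, weight scales) is relied upon by
+        # get_loger_stats(), which reads optimizer.param_groups[1] and [2].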
+
+        if self.scale_act_lr is not None:
+            scale_act_group['lr'] = self.scale_act_lr
+        if self.scale_act_decay is not None:
+            scale_act_group['weight_decay'] = self.scale_act_decay
+
+        if self.scale_w_lr is not None:
+            scale_w_group['lr'] = self.scale_w_lr
+        if self.scale_w_decay is not None:
+            scale_w_group['weight_decay'] = self.scale_w_decay
+
+        return [base_group, scale_act_group, scale_w_group]
 
-        if self.scale_lr is not None:
-            scale_group['lr'] = self.scale_lr
-        if self.scale_decay is not None:
-            scale_group['weight_decay'] = self.scale_decay
+    def _prepare_model_impl(self):
+        super(LSQQuatizer, self)._prepare_model_impl()
 
-        return [base_group, scale_group]
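+        # Attach a learned scale for the weights of every quantized module,
+        # initialized from the float weights' dynamic range.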
+        for ptq in self.params_to_quantize:
+            m = ptq.module
+            m.learned_scale_weight = nn.Parameter(LSQParamsQuantization.initialize_scale(
+                ptq.module.float_weight, ptq.num_bits, self.per_channel_wts))
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index c3814db..f7a2bf1 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -33,7 +33,7 @@ def has_bias(module):
     return hasattr(module, 'bias') and module.bias is not None
 
 
-def hack_float_backup_parameter(module, name, num_bits, sat_mode):
+def hack_float_backup_parameter(module, name, num_bits, sat_mode=None):
     try:
         data = dict(module.named_parameters())[name].data
     except KeyError:
@@ -52,7 +52,7 @@ def hack_float_backup_parameter(module, name, num_bits, sat_mode):
     if not first:
         module.repr_mod += ' ; '
     module.repr_mod += '{0} --> {1} bits'.format(name, num_bits)
-    if 'weight' in name and num_bits < 8:
+    if 'weight' in name and num_bits < 8 and sat_mode is not None:
         sat_mode_str = str(sat_mode).split('.')[1] if sat_mode else 'No'
         module.repr_mod += ', wts_sat_mode --> {0}'.format(sat_mode_str)
 
@@ -165,6 +165,12 @@ class Quantizer(object):
     def prepare_model(self):
         self._prepare_model_impl()
 
+        # If an optimizer was passed, assume we need to update it
+        if self.optimizer:
+            optimizer_type = type(self.optimizer)
+            new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults)
+            self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups})
+
         msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))
 
     def _prepare_model_impl(self):
@@ -193,7 +199,7 @@ class Quantizer(object):
                 fp_attr_name = param_name
                 if self.train_with_fp_copy:
-                    # ugly way to pass wts_sat_mode
-                    hack_float_backup_parameter(module, param_name, n_bits, self.wts_sat_mode)
+                    hack_float_backup_parameter(module, param_name, n_bits)
                     fp_attr_name = FP_BKP_PREFIX + param_name
                 self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, n_bits))
 
@@ -201,11 +207,6 @@ class Quantizer(object):
                 msglogger.info(
                     "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, n_bits))
 
-        # If an optimizer was passed, assume we need to update it
-        if self.optimizer:
-            optimizer_type = type(self.optimizer)
-            new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults)
-            self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups})
 
     def _pre_process_container(self, container, prefix=''):
         # Iterate through model, insert quantization functions as appropriate
diff --git a/examples/quantization/quant_aware_train/lsq_4bit.yaml b/examples/quantization/quant_aware_train/lsq_4bit.yaml
index 96152b4..1ea5b89 100644
--- a/examples/quantization/quant_aware_train/lsq_4bit.yaml
+++ b/examples/quantization/quant_aware_train/lsq_4bit.yaml
@@ -2,27 +2,23 @@ quantizers:
   lsq_quantizer:
     class: LSQQuatizer
     bits_activations: 4
-    bits_weights: null
+    bits_weights: 4
 #    num_bits_inputs: 8
     mode: 'ASYMMETRIC_UNSIGNED'  # Can try "SYMMETRIC" as well
     per_channel_wts: True
     quantize_inputs: False
 #    scale_decay: 0.0001
-    scale_lr: 0.1
-    zero_point_lr: 0.001
+    scale_act_lr: 0.1
+    scale_w_lr: 0.0001
     bits_overrides:
       conv1:
-        acts: null
-#        wts: 8
+#        acts: 8
+        wts: 8
       relu:
         acts: 8
-      bn1:
-        acts: null
-      .*\.bn1:
-        acts: null
       fc:
-        acts: 8
-#        wts: 8
+#        acts: 8
+        wts: 8
 
 policies:
     - quantizer:
-- 
GitLab