diff --git a/predtuner/approxes/_copy.py b/predtuner/approxes/_copy.py
new file mode 100644
index 0000000000000000000000000000000000000000..acb0f78186cb14108f859e5d8a41c6b95b4b02a3
--- /dev/null
+++ b/predtuner/approxes/_copy.py
@@ -0,0 +1,45 @@
+import copy
+from typing import TypeVar
+
+from torch.nn import Module, Parameter
+
+T = TypeVar('T')
+
+
+def module_only_deepcopy(obj: T, memo=None) -> T:
+    """Recursively copy `obj`, copying only the modules and not the weights.
+    In the return value, all weights are still shared with those in `obj`."""
+
+    memo = {} if memo is None else memo
+
+    def recursive_scan_parameters(obj_):
+        # Don't recurse down primitive types.
+        if isinstance(obj_, (int, float, bool, complex, str, bytes)):
+            return
+        # Additionally share all buffers of Module. For example, this accounts for
+        # running_{mean|var} in BatchNorm.
+        if isinstance(obj_, Module):
+            buffers = obj_.__dict__.get('_buffers', {})
+            for buffer in buffers.values():
+                memo[id(buffer)] = buffer
+        # Share all parameters.
+        if isinstance(obj_, Parameter):
+            memo[id(obj_)] = obj_
+        # Walk down all other types.
+        elif isinstance(obj_, dict):
+            for k in obj_.keys():
+                recursive_scan_parameters(k)
+            for v in obj_.values():
+                recursive_scan_parameters(v)
+        elif isinstance(obj_, (list, tuple)):
+            for x in obj_:
+                recursive_scan_parameters(x)
+        elif hasattr(obj_, '__dict__'):
+            for x in obj_.__dict__.values():
+                recursive_scan_parameters(x)
+
+    # Populate `memo`, then deepcopy with `memo` so that objects in `memo`
+    # are shared rather than copied.
+    recursive_scan_parameters(obj)
+    # noinspection PyArgumentList
+    copied = copy.deepcopy(obj, memo)
+    return copied
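+
+
+# A minimal sanity-check sketch of the sharing behavior (hypothetical demo
+# helper, not used elsewhere; `Sequential`/`Conv2d` are arbitrary stand-ins):
+# the copy has fresh module objects but shares the underlying Parameters.
+def _demo_module_only_deepcopy():
+    from torch.nn import Conv2d, Sequential
+
+    net = Sequential(Conv2d(3, 8, 3), Conv2d(8, 8, 3))
+    shadow = module_only_deepcopy(net)
+    assert shadow is not net and shadow[0] is not net[0]  # modules are new objects
+    assert shadow[0].weight is net[0].weight  # parameters are still shared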
diff --git a/predtuner/approxes/approxes.py b/predtuner/approxes/approxes.py
new file mode 100644
index 0000000000000000000000000000000000000000..217e25ea5b6ccfd9b692292598e517fd7af179d6
--- /dev/null
+++ b/predtuner/approxes/approxes.py
@@ -0,0 +1,372 @@
+"""Approximation techniques for torch.nn layers."""
+from typing import Iterable, List, Optional, Type
+
+import torch
+from torch.nn import Conv2d, Linear, Module, Parameter
+
+from ._copy import module_only_deepcopy
+from ..torchapp import TorchApproxKnob
+
+
+def _interpolate_first_dim(tensor: torch.Tensor, interp_indices: Iterable[int]):
+    """In-place interpolate rows of `tensor` (on dim 0) at `interp_indices`,
+    replacing each with the average of its immediate neighbors
+    (or a copy of the single neighbor, at the boundary)."""
+
+    def tensor_at(idx_: int):
+        # Interpolated rows must never be read back as interpolation sources.
+        if idx_ in interp_indices:
+            raise IndexError
+        if idx_ < 0 or idx_ >= tensor.size()[0]:
+            return torch.zeros_like(tensor[0])
+        return tensor[idx_]
+
+    for idx in interp_indices:
+        if idx < 0 or idx >= tensor.size()[0]:
+            raise IndexError
+        elif idx == 0:  # First row
+            tensor[idx] = tensor_at(1)
+        elif idx == tensor.size()[0] - 1:  # Last row
+            tensor[idx] = tensor_at(idx - 1)
+        else:  # Middle rows
+            tensor[idx] = (tensor_at(idx - 1) + tensor_at(idx + 1)) / 2.0
+    return tensor
+
+
+class PerforateConv2dStride(TorchApproxKnob):
+    r"""Simulation of strided perforated convolution for `torch.nn.Conv2d`.
+
+    Perforated convolution skips computing some entries in the output and instead
+    interpolates these values, to reduce the number of float-ops needed to complete
+    a convolution op. In this implementation, selected rows or columns of the output
+    are discarded and replaced with linearly interpolated values from the neighboring
+    rows or columns. Each channel is considered independently.
+    This implementation gives the same output as actual perforated convolution but
+    without the performance benefit.
+
+    Parameters
+    ----------
+    direction_is_row : bool
+        If True, discard and interpolate rows; otherwise, columns.
+    stride : int :math:`\in [2, +\infty)`
+        Skip (and interpolate) 1 row/column in the convolution output per `stride` elements.
+    offset : int :math:`\in [0, stride)`
+        Index of the first row/column to skip.
+    use_fp16 : bool
+        Whether to use fp16 weight/input or not.
+
+    Attributes
+    ----------
+    interp_axis : int :math:`\in \{2, 3\}`
+        The axis that will be perforated over. As the input is an NCHW tensor, if
+        `direction_is_row` then `interp_axis = 2`, otherwise `interp_axis = 3`.
+    stride : int :math:`\in [2, +\infty)`
+        Equal to parameter `stride`.
+    offset : int :math:`\in [0, stride)`
+        Equal to parameter `offset`.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        direction_is_row: bool,
+        stride: int,
+        offset: int,
+        use_fp16: bool,
+        exp_speedup: float,
+    ):
+        super().__init__(
+            name,
+            direction_is_row=direction_is_row,
+            stride=stride,
+            offset=offset,
+            use_fp16=use_fp16,
+            exp_speedup=exp_speedup,
+        )
+        assert stride >= 2
+        assert 0 <= offset < stride
+        self.interp_axis = 2 if direction_is_row else 3
+        self.stride = stride
+        self.offset = offset
+        self.fp16 = use_fp16
+        self.exp_speedup = exp_speedup
+
+    def is_applicable(self, op: Module) -> bool:
+        return isinstance(op, Conv2d)
+
+    @property
+    def deterministic(self) -> bool:
+        return True
+
+    @property
+    def expected_speedup(self) -> float:
+        return self.exp_speedup
+
+    class PerforateConv2dStrideModule(Module):
+        def __init__(self, conv: Conv2d, approx: "PerforateConv2dStride"):
+            super().__init__()
+            self.conv = conv
+            self.approx = approx
+            if self.approx.fp16:
+                self.conv = self.conv.half()
+
+        def conv_no_bias(self, x: torch.Tensor):
+            # Run the convolution with the bias temporarily detached; the bias
+            # is added back after interpolation (see `forward`).
+            if self.conv.bias is None:
+                return self.conv(x)
+            bias = self.conv.bias
+            self.conv.bias = None
+            result = self.conv(x)
+            self.conv.bias = bias
+            return result
+
+        def add_conv_bias(self, conv_output: torch.Tensor):
+            if self.conv.bias is None:
+                return conv_output
+            broadcast_bias = self.conv.bias.reshape(1, -1, 1, 1)
+            return conv_output + broadcast_bias
+
+        def forward(self, x: torch.Tensor):
+            if self.approx.fp16:
+                x = x.half()
+            x = self.conv_no_bias(x)
+            assert x.dim() == 4
+            # Put self.approx.interp_axis at the first axis temporarily
+            x = x.transpose(0, self.approx.interp_axis)
+            interp_indices = torch.tensor(
+                range(self.approx.offset, x.size(0), self.approx.stride)
+            )
+            x = _interpolate_first_dim(x, interp_indices)
+            # Put the axes back
+            x = x.transpose(0, self.approx.interp_axis)
+            x = self.add_conv_bias(x)
+            if self.approx.fp16:
+                assert x.dtype == torch.float16
+            return x.float()
+
+    def apply(self, module: Conv2d) -> PerforateConv2dStrideModule:
+        return self.PerforateConv2dStrideModule(module, self)
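+
+
+# A minimal usage sketch (hypothetical demo helper, not part of the knob API;
+# the name "perf_row_2_0" and exp_speedup=1.5 are arbitrary): wrap a Conv2d
+# with row perforation (stride 2, offset 0) and run a forward pass.
+def _demo_perforation():
+    conv = Conv2d(3, 8, kernel_size=3, padding=1)
+    knob = PerforateConv2dStride(
+        "perf_row_2_0", direction_is_row=True, stride=2, offset=0,
+        use_fp16=False, exp_speedup=1.5,
+    )
+    assert knob.is_applicable(conv)
+    approx_conv = knob.apply(conv)
+    out = approx_conv(torch.zeros(1, 3, 16, 16))
+    assert out.shape == (1, 8, 16, 16)  # shape is unchanged; rows 0, 2, ... are interpolated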
+
+
+class Conv2dSampling(TorchApproxKnob):
+    r"""Simulation of sampled convolution for `torch.nn.Conv2d`.
+
+    Skips some elements of the convolution kernel in a uniform, strided manner,
+    to reduce the number of float-ops needed to compute each output entry.
+    This implementation gives the same output as actual sampled convolution but
+    without the performance benefit.
+
+    Parameters
+    ----------
+    skip_every : int
+        Skip 1 element in the convolution kernel per `skip_every` elements.
+    skip_offset : int :math:`\in [0, +\infty)`
+        Index of the first element to be skipped.
+        For example, if `skip_every = 3` and `skip_offset = 1`, then the indices
+        skipped will be [1, 4, 7, ...].
+    interp_rate : float
+        The weight will be compensated ("interpolated") with a ratio after skipping
+        elements, which is naturally equal to :math:`1 + 1 / (skip\_every - 1)`.
+        `interp_rate` modifies this rate to :math:`1 + interp\_rate / (skip\_every - 1)`.
+    use_fp16 : bool
+        Whether to use fp16 weight/input or not.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        skip_every: int,
+        skip_offset: int,
+        interp_rate: float,
+        use_fp16: bool,
+        exp_speedup: float,
+    ):
+        super().__init__(
+            name,
+            skip_every=skip_every,
+            skip_offset=skip_offset,
+            interp_rate=interp_rate,
+            use_fp16=use_fp16,
+            exp_speedup=exp_speedup,
+        )
+        assert skip_every >= 2 and skip_offset >= 0
+        self.skip_every = skip_every
+        self.skip_offset = skip_offset
+        self.interp_rate = interp_rate
+        self.fp16 = use_fp16
+        self.exp_speedup = exp_speedup
+
+    def is_applicable(self, op: Module) -> bool:
+        return isinstance(op, Conv2d)
+
+    @property
+    def deterministic(self) -> bool:
+        return True
+
+    @property
+    def expected_speedup(self) -> float:
+        return self.exp_speedup
+
+    @staticmethod
+    def sample_conv_weight(
+        interp_rate: float, skip_every: int, skip_offset: int, weight: torch.Tensor
+    ):
+        r"""Samples (skips & interpolates) a convolution kernel according to the parameters.
+
+        For a given `weight` tensor of shape `(C1, C2, H, W)`, sample each output channel
+        (on axis 0) independently.
+        Flatten each output channel tensor into 1 dim.
+        In the general case, set elements at indices ``range(skip_offset, C2 * H * W, skip_every)``
+        to 0.
+        However, if `skip_every` == `h` == `w` == 3, we may end up skipping the same whole rows
+        for each input channel, which is undesirable.
+        Instead, increment the offset by 1 for each input channel.
+        Finally, multiply the kernel by the compensation ratio derived from the fraction of
+        elements dropped, as an interpolation (see `interp_rate` in the class docstring).
+        """
+        if len(weight.shape) != 4:
+            raise ValueError("Conv2d weight should be 4-dimensional")
+        c1, c2, h, w = weight.shape
+        if skip_every == h == w == 3:
+            # Indices (0..h*w) to skip for each input channel
+            per_chan_skip_indices = [
+                range((i_chan + skip_offset) % skip_every, h * w, skip_every)
+                for i_chan in range(c2)
+            ]
+            # Indices (0..c2*h*w) for each output channel, created by adding i*h*w for the ith channel.
+            skip_indices = torch.tensor(
+                [
+                    x + i * h * w
+                    for i, per_chan in enumerate(per_chan_skip_indices)
+                    for x in per_chan
+                ]
+            )
+        else:
+            # Indices (0..c2*h*w) to skip for each output channel
+            skip_indices = torch.arange(skip_offset, c2 * h * w, skip_every)
+        flat_weight = weight.reshape(c1, -1).clone()
+        flat_weight[:, skip_indices] = 0
+        interp_rate = 1 + (1 / (skip_every - 1) * interp_rate)
+        flat_weight *= interp_rate
+        return flat_weight.reshape_as(weight)
+
+    def apply(self, module: Conv2d) -> Conv2d:
+        # Copy only the module structure; weights are still shared with `module`.
+        copied = module_only_deepcopy(module)
+        # Overwrite only the copy's weight (`sample_conv_weight` clones it);
+        # the original weight is unchanged.
+        copied.weight = Parameter(
+            self.sample_conv_weight(
+                self.interp_rate, self.skip_every, self.skip_offset, copied.weight
+            )
+        )
+        return copied
+
+
+def _quantize_uint8(
+    tensor: torch.Tensor, range_min: float, range_max: float
+) -> torch.Tensor:
+    """Simulates quantization of `tensor` down to uint8, while still returning float values.
+
+    The data in the returned tensor will NOT be in the [0, 255] range, but only 256
+    unique float values will exist in it.
+    """
+
+    quantize_range = 256
+    input_range = range_max - range_min
+    mul = input_range / quantize_range
+    # Map tensor into the [0, 256] range.
+    affined = (tensor - range_min) / mul
+    # Convert tensor to int and back to float so it will have
+    # 256 (actually 257, following the hpvm implementation) unique float values in [0, 256].
+    # Then reverse the affine map to recover the original range.
+    quanted = torch.floor(affined).to(torch.int).to(torch.float)
+    quanted_float = quanted * mul + range_min
+    # Clip tensor to the quantization range.
+    return torch.clamp(quanted_float, range_min, range_max)
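+
+
+# A minimal sketch of the claim above (hypothetical demo helper): quantizing a
+# ramp over [0, 1] leaves at most 257 distinct float values.
+def _demo_quantize_uint8():
+    t = torch.linspace(0, 1, steps=10000)
+    q = _quantize_uint8(t, 0.0, 1.0)
+    assert len(torch.unique(q)) <= 257  # floor(affined) takes values 0..256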
+
+
+class PromiseSim(TorchApproxKnob):
+    """Simulates the PROMISE analog accelerator.
+
+    This hardware is proposed in "PROMISE: An End-to-End Design of a Programmable Mixed-Signal
+    Accelerator for Machine-Learning Algorithms."
+    """
+
+    scaling_values = [0.75, 0.64, 0.336, 0.21, 0.168, 0.14, 0.11, 0.0784, 0.005]
+
+    def __init__(self, name: str, noise_level: int, exp_speedup: float):
+        super().__init__(name, noise_level=noise_level, exp_speedup=exp_speedup)
+        self.noise_level = noise_level
+        self.exp_speedup = exp_speedup
+
+    def is_applicable(self, op: Module) -> bool:
+        return isinstance(op, (Conv2d, Linear))
+
+    @property
+    def deterministic(self) -> bool:
+        return False
+
+    @property
+    def expected_speedup(self) -> float:
+        return self.exp_speedup
+
+    def add_promise_noise(self, tensor: torch.Tensor):
+        # Multiplicative gaussian noise: tensor * (1 + N(0, scale)).
+        scale = self.scaling_values[self.noise_level]
+        noise = torch.normal(
+            mean=0.0, std=scale, size=tensor.size(), device=tensor.device
+        )
+        return noise * tensor + tensor
+
+    class PromiseSimModule(Module):
+        def __init__(self, module: Module, approx: "PromiseSim"):
+            super().__init__()
+            if not hasattr(module, "conv_ranges"):
+                raise ValueError(
+                    f"Quantization ranges of layer {module} not found"
+                )
+            self.input_r, weight_r, bias_r, self.output_r = module.conv_ranges
+            module.weight.data = _quantize_uint8(module.weight, *weight_r)
+            if module.bias is not None:
+                module.bias.data = _quantize_uint8(module.bias, *bias_r)
+            self.module = module
+            self.approx = approx
+
+        def forward(self, input_: torch.Tensor) -> torch.Tensor:
+            # Quantize input, weight, and bias (see `__init__`), and add noise to the input.
+            input_ = _quantize_uint8(input_, *self.input_r)
+            input_ = self.approx.add_promise_noise(input_)
+            output = self.module(input_)
+            # Then quantize the output as well.
+            return _quantize_uint8(output, *self.output_r)
+
+    def apply(self, module: Module) -> PromiseSimModule:
+        return self.PromiseSimModule(module, self)
+
+
+class FP16Approx(TorchApproxKnob):
+    """Approximates by reducing the precision of layer computation to float16."""
+
+    def __init__(self, name: str, exp_speedup: float):
+        super().__init__(name, exp_speedup=exp_speedup)
+        self.exp_speedup = exp_speedup
+
+    def is_applicable(self, op: Module) -> bool:
+        return isinstance(op, (Conv2d, Linear))
+
+    @property
+    def deterministic(self) -> bool:
+        return True
+
+    @property
+    def applicable_op_types(self) -> List[Type[Module]]:
+        return [Conv2d, Linear]
+
+    @property
+    def expected_speedup(self) -> float:
+        return self.exp_speedup
+
+    def is_less_approx(self, other: TorchApproxKnob) -> Optional[bool]:
+        return None
+
+    class FP16ApproxModule(Module):
+        def __init__(self, module: Module):
+            super().__init__()
+            self.module = module.half()
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            x = self.module(x.half())
+            assert x.dtype == torch.float16
+            return x.float()
+
+    def apply(self, module: Module) -> FP16ApproxModule:
+        return self.FP16ApproxModule(module)
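+
+
+# A minimal sketch of applying the knob (hypothetical demo helper; the name
+# "fp16" and exp_speedup=1.5 are arbitrary): the wrapped layer holds fp16
+# weights, and `forward` above casts results back to float32.
+def _demo_fp16_approx():
+    fc = Linear(8, 4)
+    knob = FP16Approx("fp16", exp_speedup=1.5)
+    approx_fc = knob.apply(fc)
+    assert approx_fc.module.weight.dtype == torch.float16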