Commit 5251c4d5 authored by Yifan Zhao

Added implementation for default approximations

parent f641a1b9
import copy
from typing import TypeVar

from torch.nn import Module, Parameter

T = TypeVar('T')


def module_only_deepcopy(obj: T, memo=None) -> T:
    """Recursively copy, but only the modules, not the weights.

    In the return value, all weights are still shared with those in `obj`."""
    memo = {} if memo is None else memo

    def recursive_scan_parameters(obj_):
        # Don't recurse down primitive types
        if isinstance(obj_, (int, float, bool, complex, str, bytes)):
            return
        # Additionally share all buffers of Module. For example, this accounts for
        # running_{mean|var} in BatchNorm.
        if isinstance(obj_, Module):
            buffers = obj_.__dict__.get('_buffers')
            for buffer in buffers.values():
                memo[id(buffer)] = buffer
        # Share all parameters.
        if isinstance(obj_, Parameter):
            memo[id(obj_)] = obj_
        # Walk down all other types.
        elif isinstance(obj_, dict):
            for k in obj_.keys():
                recursive_scan_parameters(k)
            for v in obj_.values():
                recursive_scan_parameters(v)
        elif isinstance(obj_, (list, tuple)):
            for x in obj_:
                recursive_scan_parameters(x)
        elif hasattr(obj_, '__dict__'):
            for x in obj_.__dict__.values():
                recursive_scan_parameters(x)

    # Populate `memo`, then deepcopy with `memo` so that objects in memo are not copied.
    recursive_scan_parameters(obj)
    # noinspection PyArgumentList
    copied = copy.deepcopy(obj, memo)
    return copied
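
# Illustrative usage sketch (not part of the original file): after copying a module
# with `module_only_deepcopy`, the module tree is new but the parameters are the
# same objects, so no weight memory is duplicated.
#
#   from torch.nn import Conv2d
#   conv = Conv2d(3, 8, kernel_size=3)
#   copied = module_only_deepcopy(conv)
#   assert copied is not conv
#   assert copied.weight is conv.weight  # parameters are shared, not copied
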
"""Approximation techniques for torch.nn layers."""
from typing import Iterable, List, Optional, Type
import torch
from torch.nn import Conv2d, Linear, Module, Parameter
from ._copy import module_only_deepcopy
from ..torchapp import TorchApproxKnob


def _interpolate_first_dim(tensor: torch.Tensor, interp_indices: Iterable[int]):
    def tensor_at(idx_: int):
        if idx_ in interp_indices:
            raise IndexError
        if idx_ < 0 or idx_ >= tensor.size()[0]:
            return torch.zeros_like(tensor[0])
        return tensor[idx_]

    for idx in interp_indices:
        if idx < 0 or idx >= tensor.size()[0]:
            raise IndexError
        elif idx == 0:  # First row
            tensor[idx] = tensor_at(1)
        elif idx == tensor.size()[0] - 1:  # Last row
            tensor[idx] = tensor_at(idx - 1)
        else:  # Middle rows
            tensor[idx] = (tensor_at(idx - 1) + tensor_at(idx + 1)) / 2.0
    return tensor
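
# A small sketch of what `_interpolate_first_dim` does (illustrative only): with
# interp_indices = [2], row 2 is replaced in place by the mean of its neighbors.
#
#   t = torch.tensor([[0.], [10.], [99.], [30.], [40.]])
#   _interpolate_first_dim(t, [2])
#   # t[2] is now (10 + 30) / 2 = 20; boundary indices instead copy their single
#   # existing neighbor (index 0 copies row 1, the last index copies the row before it).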


class PerforateConv2dStride(TorchApproxKnob):
    r"""Simulation of strided perforated convolution for `torch.nn.Conv2d`.

    Perforated convolution skips computing some entries in the output and instead interpolates
    these values, to reduce the number of float-ops needed to complete a convolution op.
    In this implementation, selected rows or columns of the output are discarded and replaced
    with linearly interpolated values from the neighboring rows or columns. Each channel is
    considered independently.
    This implementation gives the same output as actual perforated convolution but without the
    performance benefit.

    Parameters
    ----------
    direction_is_row : bool
        If True, discard and interpolate rows; otherwise, columns.
    stride : int :math:`\in [2, +\infty)`
        Skip (and interpolate) 1 output row/column per `stride` rows/columns.
    offset : int :math:`\in [0, stride)`
        Index of the first skipped row/column.

    Attributes
    ----------
    interp_axis : int :math:`\in \{2, 3\}`
        The axis that will be perforated over. As the input is an NCHW tensor, if
        `direction_is_row` then `interp_axis = 2`, otherwise `interp_axis = 3`.
    stride : int :math:`\in [2, +\infty)`
        Equal to parameter `stride`.
    offset : int :math:`\in [0, stride)`
        Equal to parameter `offset`.
    """

    def __init__(
        self,
        name: str,
        direction_is_row: bool,
        stride: int,
        offset: int,
        use_fp16: bool,
        exp_speedup: float,
    ):
        super().__init__(
            name,
            direction_is_row=direction_is_row,
            stride=stride,
            offset=offset,
            use_fp16=use_fp16,
            exp_speedup=exp_speedup,
        )
        assert stride >= 2
        assert 0 <= offset < stride
        self.interp_axis = 2 if direction_is_row else 3
        self.stride = stride
        self.offset = offset
        self.fp16 = use_fp16
        self.exp_speedup = exp_speedup

    def is_applicable(self, op: Module) -> bool:
        return isinstance(op, Conv2d)

    @property
    def deterministic(self) -> bool:
        return True

    @property
    def expected_speedup(self) -> float:
        return self.exp_speedup

    class PerforateConv2dStrideModule(Module):
        def __init__(self, conv: Conv2d, approx: "PerforateConv2dStride"):
            super().__init__()
            self.conv = conv
            self.approx = approx
            if self.approx.fp16:
                self.conv = self.conv.half()

        def conv_no_bias(self, x: torch.Tensor):
            if self.conv.bias is None:
                return self.conv(x)
            bias = self.conv.bias
            self.conv.bias = None
            result = self.conv(x)
            self.conv.bias = bias
            return result

        def add_conv_bias(self, conv_output: torch.Tensor):
            if self.conv.bias is None:
                return conv_output
            broadcast_bias = self.conv.bias.reshape(1, -1, 1, 1)
            return conv_output + broadcast_bias

        def forward(self, x: torch.Tensor):
            if self.approx.fp16:
                x = x.half()
            x = self.conv_no_bias(x)
            assert x.dim() == 4
            # Put self.approx.interp_axis to first axis temporarily
            x = x.transpose(0, self.approx.interp_axis)
            interp_indices = torch.tensor(
                range(self.approx.offset, x.size(0), self.approx.stride)
            )
            x = _interpolate_first_dim(x, interp_indices)
            # Putting axes back
            x = x.transpose(0, self.approx.interp_axis)
            x = self.add_conv_bias(x)
            if self.approx.fp16:
                assert x.dtype == torch.float16
            return x.float()

    def apply(self, module: Conv2d) -> PerforateConv2dStrideModule:
        return self.PerforateConv2dStrideModule(module, self)
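
# Sketch of how the knob is used (illustrative; assumes TorchApproxKnob's constructor
# accepts the keyword arguments forwarded in __init__ above, and the knob name and
# speedup value here are made up):
#
#   knob = PerforateConv2dStride("perf_row_2_0", direction_is_row=True,
#                                stride=2, offset=0, use_fp16=False, exp_speedup=1.5)
#   approx_conv = knob.apply(Conv2d(3, 8, kernel_size=3, padding=1))
#   # In approx_conv's output, rows 0, 2, 4, ... of every feature map are recomputed
#   # by linear interpolation (see range(offset, H, stride) in forward above).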


class Conv2dSampling(TorchApproxKnob):
    r"""Simulation of sampled convolution for `torch.nn.Conv2d`.

    Skips some elements of the convolution kernel in a uniform, strided manner,
    to reduce the number of float-ops needed to compute each output entry.
    This implementation gives the same output as actual sampled convolution but without the
    performance benefit.

    Parameters
    ----------
    skip_every : int
        Skip 1 element in the convolution kernel per `skip_every` elements.
    skip_offset : int :math:`\in [0, +\infty)`
        Index of the first element to be skipped.
        For example, if `skip_every = 3` and `skip_offset = 1`, then the skipped indices
        will be [1, 4, 7, ...].
    interp_rate : float
        The weight is compensated ("interpolated") by a ratio after skipping elements,
        which is naturally equal to :math:`1 + 1 / (skip\_every - 1)`.
        `interp_rate` modifies this rate to :math:`1 + 1 / (skip\_every - 1) \times interp\_rate`.
    use_fp16 : bool
        Whether to use fp16 weight/input or not.
    """

    def __init__(
        self,
        name: str,
        skip_every: int,
        skip_offset: int,
        interp_rate: float,
        use_fp16: bool,
        exp_speedup: float,
    ):
        super().__init__(
            name,
            skip_every=skip_every,
            skip_offset=skip_offset,
            interp_rate=interp_rate,
            use_fp16=use_fp16,
            exp_speedup=exp_speedup,
        )
        assert skip_every >= 2 and skip_offset >= 0
        self.skip_every = skip_every
        self.skip_offset = skip_offset
        self.interp_rate = interp_rate
        self.fp16 = use_fp16
        self.exp_speedup = exp_speedup

    def is_applicable(self, op: Module) -> bool:
        return isinstance(op, Conv2d)

    @property
    def deterministic(self) -> bool:
        return True

    @property
    def expected_speedup(self) -> float:
        return self.exp_speedup

    @staticmethod
    def sample_conv_weight(
        interp_rate: float, skip_every: int, skip_offset: int, weight: torch.Tensor
    ):
        r"""Samples (skips & interpolates) the convolution kernel according to the parameters.

        For a given `weight` tensor of shape `(C1, C2, H, W)`, sample each output channel
        (on axis 0) independently.
        Flatten each output channel tensor into 1 dimension.
        In the normal case, set the elements at indices ``range(skip_offset, C2 * H * W, skip_every)``
        to 0.
        However, if `skip_every` == `h` == `w` == 3, we may end up skipping the same whole rows for
        each input channel, which is undesirable.
        Instead, increment the offset by 1 for each input channel.
        Finally, multiply the kernel by a compensation ratio to make up for the dropped elements.
        """
        if len(weight.shape) != 4:
            raise ValueError("Conv2d weight should be 4-dimensional")
        c1, c2, h, w = weight.shape
        if skip_every == h == w == 3:
            # Indices (0..h*w) to skip for each input channel
            per_chan_skip_indices = [
                range((i_chan + skip_offset) % skip_every, h * w, skip_every)
                for i_chan in range(c2)
            ]
            # Indices (0..c2*h*w) for each output channel, created by adding i*h*w for the ith channel.
            skip_indices = torch.tensor(
                [
                    x + i * h * w
                    for i, per_chan in enumerate(per_chan_skip_indices)
                    for x in per_chan
                ]
            )
        else:
            # Indices (0..c2*h*w) to skip for each output channel
            skip_indices = torch.arange(skip_offset, c2 * h * w, skip_every)
        flat_weight = weight.reshape(c1, -1).clone()
        flat_weight[:, skip_indices] = 0
        interp_rate = 1 + (1 / (skip_every - 1) * interp_rate)
        flat_weight *= interp_rate
        return flat_weight.reshape_as(weight)

    def apply(self, module: Conv2d) -> Conv2d:
        # Only copy the submodules; weights are still shared.
        copied = module_only_deepcopy(module)
        # But only write over the weight of the copied version; the original weight is unchanged.
        copied.weight = Parameter(
            self.sample_conv_weight(
                self.interp_rate, self.skip_every, self.skip_offset, copied.weight
            )
        )
        return copied
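
# A worked example of `sample_conv_weight` (illustrative): for a (1, 1, 2, 2)
# all-ones kernel with skip_every=2, skip_offset=0, interp_rate=1.0, flat indices
# 0 and 2 are zeroed and the remaining elements are scaled by 1 + 1/(2-1) = 2:
#
#   w = torch.ones(1, 1, 2, 2)
#   out = Conv2dSampling.sample_conv_weight(1.0, 2, 0, w)
#   # out == tensor([[[[0., 2.], [0., 2.]]]]); w itself is unchanged because the
#   # flattened weight is cloned before being modified.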


def _quantize_uint8(
    tensor: torch.Tensor, range_min: float, range_max: float
) -> torch.Tensor:
    """Simulates quantization of `tensor` down to uint8, while still returning float values.

    The returned data will NOT be in the [0, 255] range, but only 256 unique float
    values will exist.
    """
    quantize_range = 256
    input_range = range_max - range_min
    mul = input_range / quantize_range
    # Map tensor into [0, 256] range.
    affined = (tensor - range_min) / mul
    # Convert tensor to int and back to float so it will have
    # 256 (actually 257!; following hpvm impl) unique float values in [0, 256].
    # Then reverse-affine it to the original range.
    quanted = torch.floor(affined).to(torch.int).to(torch.float)
    quanted_float = quanted * mul + range_min
    # Clip tensor
    return torch.clamp(quanted_float, range_min, range_max)
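
# Worked example of `_quantize_uint8` (illustrative): with range [0, 1],
# mul = 1/256, so a value v maps to floor(v * 256) / 256 and is then clamped.
#
#   x = torch.tensor([0.0, 0.5001, 0.9999, 1.2])
#   _quantize_uint8(x, 0.0, 1.0)
#   # -> tensor([0.0000, 0.5000, 0.9961, 1.0000])
#   # 0.5001 -> floor(128.02)/256 = 0.5; 1.2 maps to 307/256 = 1.199... and is
#   # then clamped to the range maximum 1.0.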


class PromiseSim(TorchApproxKnob):
    """Simulates the analog accelerator PROMISE.

    This hardware is proposed in "PROMISE: An End-to-End Design of a Programmable Mixed-Signal
    Accelerator for Machine-Learning Algorithms."
    """

    scaling_values = [0.75, 0.64, 0.336, 0.21, 0.168, 0.14, 0.11, 0.0784, 0.005]

    def __init__(self, name: str, noise_level: int, exp_speedup: float):
        super().__init__(name, noise_level=noise_level, exp_speedup=exp_speedup)
        self.noise_level = noise_level
        self.exp_speedup = exp_speedup

    def is_applicable(self, op: Module) -> bool:
        return isinstance(op, (Conv2d, Linear))

    @property
    def deterministic(self) -> bool:
        return False

    @property
    def expected_speedup(self) -> float:
        return self.exp_speedup

    def add_promise_noise(self, tensor: torch.Tensor):
        scale = self.scaling_values[self.noise_level]
        noise = torch.normal(
            mean=0.0, std=scale, size=tensor.size(), device=tensor.device
        )
        return noise * tensor + tensor

    class PromiseSimModule(Module):
        def __init__(self, module: Conv2d, approx: "PromiseSim"):
            super().__init__()
            if not hasattr(module, "conv_ranges"):
                raise ValueError(
                    f"Quantization range of conv2d layer {module} not found"
                )
            self.input_r, weight_r, bias_r, self.output_r = module.conv_ranges
            module.weight.data = _quantize_uint8(module.weight, *weight_r)
            if module.bias is not None:
                module.bias.data = _quantize_uint8(module.bias, *bias_r)
            self.module = module
            self.approx = approx

        def forward(self, input_: torch.Tensor) -> torch.Tensor:
            # Quantize input, weight, bias (see __init__), and add noise to input.
            input_ = _quantize_uint8(input_, *self.input_r)
            input_ = self.approx.add_promise_noise(input_)
            output = self.module(input_)
            # Then again, quantize output.
            return _quantize_uint8(output, *self.output_r)

    def apply(self, module: Conv2d, **kwargs) -> PromiseSimModule:
        return self.PromiseSimModule(module, self)
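
# Noise model sketch (illustrative; assumes TorchApproxKnob's constructor accepts
# the keyword arguments forwarded above, and the knob name is made up): with
# noise_level=0 the scale is scaling_values[0] == 0.75, and add_promise_noise
# returns t * (1 + n) with n drawn elementwise from N(0, 0.75).
#
#   sim = PromiseSim("promise_l0", noise_level=0, exp_speedup=2.0)
#   noisy = sim.add_promise_noise(torch.ones(4))  # each entry ~ 1 + N(0, 0.75)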


class FP16Approx(TorchApproxKnob):
    """Approximates by reducing precision of layer computation to float16."""

    def __init__(self, name: str, exp_speedup: float):
        super().__init__(name, exp_speedup=exp_speedup)
        self.exp_speedup = exp_speedup

    def is_applicable(self, op: Module) -> bool:
        return isinstance(op, (Conv2d, Linear))

    @property
    def deterministic(self) -> bool:
        return True

    @property
    def applicable_op_types(self) -> List[Type[Module]]:
        return [Conv2d, Linear]

    @property
    def expected_speedup(self) -> float:
        return self.exp_speedup

    def is_less_approx(self, other: TorchApproxKnob) -> Optional[bool]:
        return None

    class FP16ApproxModule(Module):
        def __init__(self, module: Module):
            super().__init__()
            self.module = module.half()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.module(x.half())
            assert x.dtype == torch.float16
            return x.float()

    def apply(self, module: Module) -> FP16ApproxModule:
        return self.FP16ApproxModule(module)
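
# End-to-end sketch of applying a knob (illustrative; the knob name and speedup
# value are made up, and TorchApproxKnob's constructor signature is assumed to
# accept them as forwarded above):
#
#   fp16 = FP16Approx("fp16", exp_speedup=1.5)
#   layer = Linear(128, 10)
#   if fp16.is_applicable(layer):
#       approx_layer = fp16.apply(layer)  # runs the layer in float16, returns float32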