Source code for qtorch.quant.quant_function

import torch
from qtorch import Number, FixedPoint, BlockFloatingPoint, FloatingPoint
from torch.utils.cpp_extension import load
import os

# JIT-compile the CPU quantization kernels the first time this module is imported.
current_path = os.path.dirname(os.path.realpath(__file__))
quant_cpu = load(
    name="quant_cpu",
    sources=[
        os.path.join(current_path, "quant_cpu/quant_cpu.cpp"),
        os.path.join(current_path, "quant_cpu/bit_helper.cpp"),
        os.path.join(current_path, "quant_cpu/sim_helper.cpp"),
    ],
)

# Build the CUDA kernels only when a GPU is available; otherwise alias them to
# the CPU implementation so the same entry points exist either way.
if torch.cuda.is_available():
    quant_cuda = load(
        name="quant_cuda",
        sources=[
            os.path.join(current_path, "quant_cuda/quant_cuda.cpp"),
            os.path.join(current_path, "quant_cuda/bit_helper.cu"),
            os.path.join(current_path, "quant_cuda/sim_helper.cu"),
            os.path.join(current_path, "quant_cuda/block_kernel.cu"),
            os.path.join(current_path, "quant_cuda/float_kernel.cu"),
            os.path.join(current_path, "quant_cuda/fixed_point_kernel.cu"),
            os.path.join(current_path, "quant_cuda/quant.cu"),
        ],
    )
else:
    quant_cuda = quant_cpu

__all__ = ["fixed_point_quantize", "block_quantize", "float_quantize", "quantizer"]


def assert_wl_fl(wl, fl, stage=""):
    # A fractional length without a word length (wl == -1) is inconsistent.
    if wl == -1 and fl != -1:
        raise ValueError("fixed point {}: invalid combination of wl {} and fl {}".format(stage, wl, fl))


def get_module(x):
    if x.is_cuda:
        quant_module = quant_cuda
    else:
        quant_module = quant_cpu
    return quant_module


def quantizer(
    forward_number=None,
    backward_number=None,
    forward_rounding="stochastic",
    backward_rounding="stochastic",
    clamping_grad_zero=False,
    backward_hooks=[],
):
    """
    Creates a quantization function that can quantize the forward and backward
    passes differently.

    Args:
        - :param: forward_number (qtorch.Number, optional) : the number format used for forward
          quantization. If None, the forward quantization is an identity mapping.
        - :param: backward_number (qtorch.Number, optional) : the number format used for backward
          quantization. If None, the backward quantization is an identity mapping.
        - :param: forward_rounding (string) : rounding mode, "stochastic" or "nearest" (default: "stochastic")
        - :param: backward_rounding (string) : rounding mode, "stochastic" or "nearest" (default: "stochastic")
        - :param: clamping_grad_zero (bool) : zero out the gradient of numbers that are clamped
          during the forward pass. Currently requires forward_number to be a fixed point number.
        - :param: backward_hooks (iterable) : iterable of functions applied to gradients before
          backward quantization. For example, this can be used to support custom scaling.

    Returns:
        A quantization function as specified (torch.Tensor -> torch.Tensor)
    """
    for rounding in [forward_rounding, backward_rounding]:
        assert rounding in ["stochastic", "nearest"], "invalid rounding type {:s}".format(rounding)
    for num in [forward_number, backward_number]:
        if num is not None:
            assert isinstance(num, Number)

    # Pick the forward quantization kernel based on number format and rounding mode.
    if not clamping_grad_zero:
        if forward_rounding == "nearest":
            if type(forward_number) == BlockFloatingPoint:
                forward_quant = lambda x, quant_module: quant_module.block_quantize_nearest(
                    x, forward_number.wl, forward_number.dim
                )
            elif type(forward_number) == FixedPoint:
                forward_quant = lambda x, quant_module: quant_module.fixed_point_quantize_nearest(
                    x,
                    forward_number.wl,
                    forward_number.fl,
                    forward_number.clamp,
                    forward_number.symmetric,
                )
            elif type(forward_number) == FloatingPoint:
                forward_quant = lambda x, quant_module: quant_module.float_quantize_nearest(
                    x, forward_number.man, forward_number.exp
                )
        elif forward_rounding == "stochastic":
            if type(forward_number) == BlockFloatingPoint:
                forward_quant = lambda x, quant_module: quant_module.block_quantize_stochastic(
                    x, forward_number.wl, forward_number.dim
                )
            elif type(forward_number) == FixedPoint:
                forward_quant = lambda x, quant_module: quant_module.fixed_point_quantize_stochastic(
                    x,
                    forward_number.wl,
                    forward_number.fl,
                    forward_number.clamp,
                    forward_number.symmetric,
                )
            elif type(forward_number) == FloatingPoint:
                forward_quant = lambda x, quant_module: quant_module.float_quantize_stochastic(
                    x, forward_number.man, forward_number.exp
                )
    else:
        # Zeroing out clamped gradients needs the masked fixed point kernels,
        # which also return a mask of the clamped entries.
        if type(forward_number) == FixedPoint or forward_number is None:
            assert (
                forward_number is None or forward_number.clamp is True
            ), "must use clamping if zeroing out clamped gradients"
            if forward_rounding == "nearest":
                forward_quant = lambda x, quant_module: quant_module.fixed_point_quantize_nearest_mask(
                    x, forward_number.wl, forward_number.fl, forward_number.symmetric
                )
            elif forward_rounding == "stochastic":
                forward_quant = lambda x, quant_module: quant_module.fixed_point_quantize_stochastic_mask(
                    x, forward_number.wl, forward_number.fl, forward_number.symmetric
                )
        else:
            raise ValueError("zeroing out clamped gradients only supports fixed point numbers.")

    # Pick the backward quantization kernel the same way.
    if backward_rounding == "nearest":
        if type(backward_number) == BlockFloatingPoint:
            backward_quant = lambda a, quant_module: quant_module.block_quantize_nearest(
                a, backward_number.wl, backward_number.dim
            )
        elif type(backward_number) == FixedPoint:
            backward_quant = lambda a, quant_module: quant_module.fixed_point_quantize_nearest(
                a,
                backward_number.wl,
                backward_number.fl,
                backward_number.clamp,
                backward_number.symmetric,
            )
        elif type(backward_number) == FloatingPoint:
            backward_quant = lambda a, quant_module: quant_module.float_quantize_nearest(
                a, backward_number.man, backward_number.exp
            )
    elif backward_rounding == "stochastic":
        if type(backward_number) == BlockFloatingPoint:
            backward_quant = lambda a, quant_module: quant_module.block_quantize_stochastic(
                a, backward_number.wl, backward_number.dim
            )
        elif type(backward_number) == FixedPoint:
            backward_quant = lambda a, quant_module: quant_module.fixed_point_quantize_stochastic(
                a,
                backward_number.wl,
                backward_number.fl,
                backward_number.clamp,
                backward_number.symmetric,
            )
        elif type(backward_number) == FloatingPoint:
            backward_quant = lambda a, quant_module: quant_module.float_quantize_stochastic(
                a, backward_number.man, backward_number.exp
            )

    if not clamping_grad_zero:

        class Rounding(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                if forward_number is None:
                    return x
                quant_module = get_module(x)
                out = forward_quant(x.contiguous(), quant_module)
                return out

            @staticmethod
            def backward(ctx, grad_output):
                if ctx.needs_input_grad[0]:
                    if backward_number is None:
                        grad_input = grad_output
                    else:
                        quant_module = get_module(grad_output)
                        grad_input = backward_quant(grad_output.contiguous(), quant_module)
                else:
                    grad_input = None
                return grad_input

    else:

        class Rounding(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                if forward_number is None:
                    ctx.mask = torch.zeros_like(x).bool()
                    return x
                else:
                    quant_module = get_module(x)
                    out, mask = forward_quant(x.contiguous(), quant_module)
                    ctx.mask = mask
                    return out

            @staticmethod
            def backward(ctx, grad_output):
                if ctx.needs_input_grad[0]:
                    if backward_number is None:
                        grad_input = grad_output
                    else:
                        quant_module = get_module(grad_output)
                        # grad_output = grad_output.contiguous().masked_fill_(ctx.mask, 0)
                        for f in backward_hooks:
                            grad_output = f(grad_output)
                        grad_input = backward_quant(grad_output.contiguous(), quant_module).masked_fill(
                            ctx.mask.bool(), 0
                        )
                else:
                    grad_input = None
                return grad_input

    return Rounding.apply
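
# A minimal usage sketch (illustrative, not part of the library source):
# build a quantization function that rounds activations to an 8-bit fixed
# point format in the forward pass and stochastically rounds gradients to a
# low-precision float in the backward pass. The formats shown are arbitrary
# example choices.
#
#   act_quant = quantizer(
#       forward_number=FixedPoint(wl=8, fl=4),
#       backward_number=FloatingPoint(exp=5, man=2),
#       forward_rounding="nearest",
#       backward_rounding="stochastic",
#   )
#   y = act_quant(torch.randn(10, requires_grad=True))
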
def fixed_point_quantize(x, wl, fl, clamp=True, symmetric=False, rounding="stochastic"):
    """
    Quantize a single precision floating point tensor into a low-precision fixed point tensor.

    Args:
        - :param: `x` (torch.Tensor) : the single precision number to be quantized
        - :param: `wl` (int) : word length of the fixed point number being simulated
        - :param: `fl` (int) : fractional length of the fixed point number being simulated
        - :param: `clamp` (bool, optional) : clamp input numbers into the representable range.
          If false, the quantization only simulates the effect on precision
        - :param: `symmetric` (bool, optional) : discard the minimum representable number to
          make the representable range symmetric
        - :param: `rounding` (string) : rounding mode, "stochastic" or "nearest" (default: "stochastic")

    Returns:
        - a quantized low-precision fixed point number (torch.Tensor)
    """
    assert isinstance(x, torch.Tensor)
    assert rounding in ["stochastic", "nearest"]
    assert_wl_fl(wl, fl)
    quant_module = get_module(x)
    if rounding == "nearest":
        out = quant_module.fixed_point_quantize_nearest(x.contiguous(), wl, fl, clamp, symmetric)
    elif rounding == "stochastic":
        out = quant_module.fixed_point_quantize_stochastic(x.contiguous(), wl, fl, clamp, symmetric)
    return out
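
# A minimal usage sketch (illustrative): quantize a tensor to an 8-bit fixed
# point format with 4 fractional bits using nearest rounding.
#
#   x = torch.randn(5)
#   x_q = fixed_point_quantize(x, wl=8, fl=4, rounding="nearest")
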
def block_quantize(x, wl, dim=-1, rounding="stochastic"):
    """
    Quantize a single precision floating point tensor into a low-precision block floating point tensor.

    Args:
        - :param: `x` (torch.Tensor) : the single precision number to be quantized
        - :param: `wl` (int) : word length of the block floating point number being simulated
        - :param: `dim` (int, optional) : the dimension along which to block the tensor;
          numbers within a block share one exponent (default: -1)
        - :param: `rounding` (string) : rounding mode, "stochastic" or "nearest" (default: "stochastic")

    Returns:
        - a quantized low-precision block floating point number (torch.Tensor)
    """
    assert isinstance(x, torch.Tensor), "x is not a single precision Floating Point Tensor"
    assert rounding in ["stochastic", "nearest"], "invalid rounding mode, {}".format(rounding)
    quant_module = get_module(x)
    if rounding == "nearest":
        out = quant_module.block_quantize_nearest(x.contiguous(), wl, dim)
    elif rounding == "stochastic":
        out = quant_module.block_quantize_stochastic(x.contiguous(), wl, dim)
    return out
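
# A minimal usage sketch (illustrative): quantize with an 8-bit word length,
# sharing an exponent along the last dimension, using stochastic rounding.
#
#   x = torch.randn(4, 4)
#   x_q = block_quantize(x, wl=8, dim=-1, rounding="stochastic")
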
def float_quantize(x, exp, man, rounding="stochastic"):
    """
    Quantize a single precision floating point tensor into a low-precision floating point tensor.

    Args:
        - :param: `x` (torch.Tensor) : the single precision number (torch.Tensor) to be quantized
        - :param: `exp` (int) : number of bits allocated for the exponent
        - :param: `man` (int) : number of bits allocated for the mantissa, not counting the virtual bit
        - :param: `rounding` (string) : rounding mode, "stochastic" or "nearest" (default: "stochastic")

    Returns:
        - a quantized low-precision floating point number (torch.Tensor)
    """
    assert isinstance(x, torch.Tensor), "x is not a single precision Floating Point Tensor"
    assert rounding in ["stochastic", "nearest"], "invalid rounding mode, {}".format(rounding)
    quant_module = get_module(x)
    if rounding == "nearest":
        out = quant_module.float_quantize_nearest(x.contiguous(), man, exp)
    elif rounding == "stochastic":
        out = quant_module.float_quantize_stochastic(x.contiguous(), man, exp)
    return out
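
# A minimal usage sketch (illustrative): simulate an IEEE half-precision-like
# format (5 exponent bits, 10 mantissa bits) with nearest rounding.
#
#   x = torch.randn(5)
#   x_q = float_quantize(x, exp=5, man=10, rounding="nearest")
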