# import torch # import torch.nn as nn # from torch.ao.quantization.observer import MinMaxObserver # # Step 1: Define a dummy nn.Conv2d # x = torch.randn(1, 3, 32, 32) # conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1) # main_output = conv(x) # # Step 2: Attach quantization metadata to the original module # conv.activation_dtype = "uint8" # conv.parameter_dtype = "uint8" # conv.activation_observer = MinMaxObserver # conv.parameter_observer = MinMaxObserver # # Step 3: Wrap it with QuantizedConv2d # qconv = QuantizedConv2d(old_module=conv) # # Step 4: Run dummy input # output = qconv(x) # # Step 5: Check output # print("Output shape:", output, main_output) # import torch # import torch.nn as nn # import torch.nn.functional as F # from torch.ao.quantization import FakeQuantize, MovingAverageMinMaxObserver # # ====================================================================================== # # 1. Helper Function for Quantization Configuration # # ====================================================================================== # def get_quant_config(bits: int, is_symmetric: bool): # """Returns quant_min, quant_max, and torch.dtype for a given bitwidth.""" # if is_symmetric: # # For symmetric quantization (typically for weights) # if bits == 8: # return -128, 127, torch.qint8 # elif bits == 4: # return -8, 7, torch.qint8 # else: # raise ValueError(f"Unsupported symmetric bitwidth: {bits}") # else: # # For asymmetric quantization (typically for activations) # if bits == 8: # return 0, 255, torch.quint8 # elif bits == 4: # return 0, 15, torch.quint8 # else: # raise ValueError(f"Unsupported asymmetric bitwidth: {bits}") # # ====================================================================================== # # 2. Base Class for All Quantized Modules # # ====================================================================================== # class QuantizationMixin(nn.Module): # """ # Base mixin for custom quantized modules. # Handles the creation of FakeQuantize modules for inputs, outputs, and parameters. # """ # def __init__(self, old_module: nn.Module, activation_bits: int = 8, weight_bits: int = 8): # super().__init__() # self.old_module = old_module # # Activation Quantizer (asymmetric) # act_qmin, act_qmax, act_dtype = get_quant_config(bits=activation_bits, is_symmetric=False) # self.input_quantizer = FakeQuantize( # observer=MovingAverageMinMaxObserver, quant_min=act_qmin, quant_max=act_qmax, # dtype=act_dtype, qscheme=torch.per_tensor_affine, reduce_range=False # ) # self.output_quantizer = FakeQuantize( # observer=MovingAverageMinMaxObserver, quant_min=act_qmin, quant_max=act_qmax, # dtype=act_dtype, qscheme=torch.per_tensor_affine, reduce_range=False # ) # # Weight Quantizer (symmetric) # self.param_quantizers = nn.ModuleDict() # if not list(self.old_module.named_parameters(recurse=False)): # return # weight_qmin, weight_qmax, weight_dtype = get_quant_config(bits=weight_bits, is_symmetric=True) # weight_qscheme = torch.per_tensor_symmetric # if isinstance(self.old_module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): # weight_qscheme = torch.per_channel_affine # for name, _ in self.old_module.named_parameters(recurse=False): # if 'weight' in name: # self.param_quantizers[name] = FakeQuantize( # observer=MovingAverageMinMaxObserver, quant_min=weight_qmin, quant_max=weight_qmax, # dtype=weight_dtype, qscheme=weight_qscheme, reduce_range=False, # ch_axis=0 if weight_qscheme == torch.per_channel_affine else -1 # ) # def forward(self, *args, **kwargs): # raise NotImplementedError("Forward pass must be implemented by subclasses.") # # ====================================================================================== # # 3. CONVOLUTIONAL AND LINEAR LAYERS # # ====================================================================================== # class QuantizedConv1d(QuantizationMixin): # def forward(self, x): # qx = self.input_quantizer(x) # qw = self.param_quantizers['weight'](self.old_module.weight) # out = F.conv1d(qx, qw, self.old_module.bias, self.old_module.stride, self.old_module.padding, self.old_module.dilation, self.old_module.groups) # return self.output_quantizer(out) # class QuantizedConv2d(QuantizationMixin): # def forward(self, x): # qx = self.input_quantizer(x) # qw = self.param_quantizers['weight'](self.old_module.weight) # out = F.conv2d(qx, qw, self.old_module.bias, self.old_module.stride, self.old_module.padding, self.old_module.dilation, self.old_module.groups) # return self.output_quantizer(out) # class QuantizedConv3d(QuantizationMixin): # def forward(self, x): # qx = self.input_quantizer(x) # qw = self.param_quantizers['weight'](self.old_module.weight) # out = F.conv3d(qx, qw, self.old_module.bias, self.old_module.stride, self.old_module.padding, self.old_module.dilation, self.old_module.groups) # return self.output_quantizer(out) # class QuantizedLinear(QuantizationMixin): # def forward(self, x): # qx = self.input_quantizer(x) # qw = self.param_quantizers['weight'](self.old_module.weight) # out = F.linear(qx, qw, self.old_module.bias) # return self.output_quantizer(out) # # ====================================================================================== # # 4. ACTIVATION FUNCTIONS # # ====================================================================================== # class QuantizedReLU(QuantizationMixin): # def forward(self, x): # # Note: In a fused block (Conv-BN-ReLU), the input 'x' is already quantized. # # Calling input_quantizer again is idempotent and harmless. # return self.output_quantizer(F.relu(self.input_quantizer(x))) # class QuantizedGELU(QuantizationMixin): # def forward(self, x): # return self.output_quantizer(F.gelu(self.input_quantizer(x))) # class QuantizedSiLU(QuantizationMixin): # def forward(self, x): # return self.output_quantizer(F.silu(self.input_quantizer(x))) # # ====================================================================================== # # 5. POOLING AND PASSTHROUGH LAYERS (No quantization needed, just passthrough) # # ====================================================================================== # class PassthroughWrapper(nn.Module): # """A simple wrapper for layers that don't need quantization logic.""" # def __init__(self, old_module, **kwargs): # super().__init__() # self.old_module = old_module # def forward(self, x): # return self.old_module(x) # QuantizedMaxPool2d = PassthroughWrapper # QuantizedAdaptiveAvgPool2d = PassthroughWrapper # QuantizedDropout = PassthroughWrapper # QuantizedIdentity = PassthroughWrapper # # ====================================================================================== # # 6. NORMALIZATION LAYERS # # ====================================================================================== # # # # !!! IMPORTANT NOTE ON BATCHNORM !!! # # # # A `QuantizedBatchNorm` is INTENTIONALLY OMITTED. During inference (and PTQ), # # BatchNorm layers should be "fused" or "folded" into the preceding Conv/Linear # # layer. You must perform this fusion on the FP32 model BEFORE applying these # # quantization wrappers. Quantizing BatchNorm as a standalone module is a known # # anti-pattern that severely degrades accuracy. # # # # Example of fusion: # # >>> from torch.ao.quantization import fuse_modules # # >>> model_fp32 = ... # # >>> fuse_modules(model_fp32.conv1, model_fp32.bn1, inplace=True) # # # class QuantizedLayerNorm(QuantizationMixin): # def forward(self, x): # qx = self.input_quantizer(x) # qw = self.param_quantizers['weight'](self.old_module.weight) # # LayerNorm is unique; it uses the functional form with parameters # out = F.layer_norm(qx, self.old_module.normalized_shape, qw, self.old_module.bias, self.old_module.eps) # return self.output_quantizer(out) # # ====================================================================================== # # 7. ELEMENT-WISE OPERATIONS # # ====================================================================================== # class QuantizedAdd(QuantizationMixin): # """Wrapper for element-wise addition, crucial for residual connections.""" # def __init__(self, old_module=nn.Identity(), activation_bits=8, **kwargs): # super().__init__(old_module, activation_bits=activation_bits) # # Need a second input quantizer # act_qmin, act_qmax, act_dtype = get_quant_config(bits=activation_bits, is_symmetric=False) # self.input_quantizer_2 = FakeQuantize( # observer=MovingAverageMinMaxObserver, quant_min=act_qmin, quant_max=act_qmax, dtype=act_dtype # ) # def forward(self, x1, x2): # # NOTE: For perfect accuracy, both inputs should have the same scale/zero-point. # # This requires sharing observers, which adds complexity to the replacement logic. # qx1 = self.input_quantizer(x1) # qx2 = self.input_quantizer_2(x2) # return self.output_quantizer(torch.add(qx1, qx2)) # class QuantizedMul(QuantizationMixin): # """Wrapper for element-wise multiplication.""" # def __init__(self, old_module=nn.Identity(), activation_bits=8, **kwargs): # super().__init__(old_module, activation_bits=activation_bits) # act_qmin, act_qmax, act_dtype = get_quant_config(bits=activation_bits, is_symmetric=False) # self.input_quantizer_2 = FakeQuantize( # observer=MovingAverageMinMaxObserver, quant_min=act_qmin, quant_max=act_qmax, dtype=act_dtype # ) # def forward(self, x1, x2): # qx1 = self.input_quantizer(x1) # qx2 = self.input_quantizer_2(x2) # return self.output_quantizer(torch.mul(qx1, qx2))