"""Custom fake-quantization wrappers built on torch.ao.quantization.

Quick usage check (uses the classes defined below):

    x = torch.randn(1, 3, 32, 32)
    conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
    main_output = conv(x)

    # Wrap the FP32 module; bit-widths are configured via the constructor.
    qconv = QuantizedConv2d(old_module=conv, activation_bits=8, weight_bits=8)
    output = qconv(x)

    print("Output shapes:", output.shape, main_output.shape)

Both tensors have shape (1, 8, 32, 32); the wrapped output differs slightly from
the FP32 output because of the simulated (fake) quantization noise.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import (
    FakeQuantize,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
)
# ======================================================================================
# 1. Helper Function for Quantization Configuration
# ======================================================================================
def get_quant_config(bits: int, is_symmetric: bool):
    """Returns quant_min, quant_max, and torch.dtype for a given bitwidth."""
    if is_symmetric:
        # For symmetric quantization (typically for weights)
        if bits == 8:
            return -128, 127, torch.qint8
        elif bits == 4:
            return -8, 7, torch.qint8
        else:
            raise ValueError(f"Unsupported symmetric bitwidth: {bits}")
    else:
        # For asymmetric quantization (typically for activations)
        if bits == 8:
            return 0, 255, torch.quint8
        elif bits == 4:
            return 0, 15, torch.quint8
        else:
            raise ValueError(f"Unsupported asymmetric bitwidth: {bits}")
# ======================================================================================
# 2. Base Class for All Quantized Modules
# ======================================================================================
class QuantizationMixin(nn.Module):
    """
    Base mixin for custom quantized modules.
    Handles the creation of FakeQuantize modules for inputs, outputs, and parameters.
    """
    def __init__(self, old_module: nn.Module, activation_bits: int = 8, weight_bits: int = 8):
        super().__init__()
        self.old_module = old_module

        # Activation quantizers (asymmetric, per-tensor)
        act_qmin, act_qmax, act_dtype = get_quant_config(bits=activation_bits, is_symmetric=False)
        self.input_quantizer = FakeQuantize(
            observer=MovingAverageMinMaxObserver, quant_min=act_qmin, quant_max=act_qmax,
            dtype=act_dtype, qscheme=torch.per_tensor_affine, reduce_range=False
        )
        self.output_quantizer = FakeQuantize(
            observer=MovingAverageMinMaxObserver, quant_min=act_qmin, quant_max=act_qmax,
            dtype=act_dtype, qscheme=torch.per_tensor_affine, reduce_range=False
        )

        # Weight quantizers (symmetric). Convolutions use per-channel quantization
        # along the output-channel axis; everything else stays per-tensor.
        self.param_quantizers = nn.ModuleDict()
        if not list(self.old_module.named_parameters(recurse=False)):
            return
        weight_qmin, weight_qmax, weight_dtype = get_quant_config(bits=weight_bits, is_symmetric=True)
        if isinstance(self.old_module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            weight_observer = MovingAveragePerChannelMinMaxObserver
            weight_qscheme = torch.per_channel_symmetric
            per_channel_kwargs = {"ch_axis": 0}
        else:
            weight_observer = MovingAverageMinMaxObserver
            weight_qscheme = torch.per_tensor_symmetric
            per_channel_kwargs = {}
        for name, _ in self.old_module.named_parameters(recurse=False):
            if 'weight' in name:
                self.param_quantizers[name] = FakeQuantize(
                    observer=weight_observer, quant_min=weight_qmin, quant_max=weight_qmax,
                    dtype=weight_dtype, qscheme=weight_qscheme, reduce_range=False,
                    **per_channel_kwargs
                )

    def forward(self, *args, **kwargs):
        raise NotImplementedError("Forward pass must be implemented by subclasses.")
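

# A minimal calibration sketch (assumptions: `model` already has its layers swapped
# for the Quantized* wrappers below, and `calib_loader` yields representative input
# batches; both names are placeholders). The FakeQuantize observers update their
# min/max ranges on every forward pass; freezing them afterwards fixes the
# scale/zero-point for evaluation:
#
#     with torch.no_grad():
#         for batch in calib_loader:
#             model(batch)
#     model.apply(torch.ao.quantization.disable_observer)  # freeze qparams
#     model.eval()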
# ======================================================================================
# 3. CONVOLUTIONAL AND LINEAR LAYERS
# ======================================================================================
class QuantizedConv1d(QuantizationMixin):
    def forward(self, x):
        qx = self.input_quantizer(x)
        qw = self.param_quantizers['weight'](self.old_module.weight)
        out = F.conv1d(qx, qw, self.old_module.bias, self.old_module.stride,
                       self.old_module.padding, self.old_module.dilation, self.old_module.groups)
        return self.output_quantizer(out)


class QuantizedConv2d(QuantizationMixin):
    def forward(self, x):
        qx = self.input_quantizer(x)
        qw = self.param_quantizers['weight'](self.old_module.weight)
        out = F.conv2d(qx, qw, self.old_module.bias, self.old_module.stride,
                       self.old_module.padding, self.old_module.dilation, self.old_module.groups)
        return self.output_quantizer(out)


class QuantizedConv3d(QuantizationMixin):
    def forward(self, x):
        qx = self.input_quantizer(x)
        qw = self.param_quantizers['weight'](self.old_module.weight)
        out = F.conv3d(qx, qw, self.old_module.bias, self.old_module.stride,
                       self.old_module.padding, self.old_module.dilation, self.old_module.groups)
        return self.output_quantizer(out)


class QuantizedLinear(QuantizationMixin):
    def forward(self, x):
        qx = self.input_quantizer(x)
        qw = self.param_quantizers['weight'](self.old_module.weight)
        out = F.linear(qx, qw, self.old_module.bias)
        return self.output_quantizer(out)
# ======================================================================================
# 4. ACTIVATION FUNCTIONS
# ======================================================================================
class QuantizedReLU(QuantizationMixin):
    def forward(self, x):
        # Note: in a fused block (Conv-BN-ReLU) the input 'x' has already been
        # fake-quantized by the previous wrapper. Re-quantizing it here is not a
        # strict no-op, but the extra rounding noise is negligible in practice.
        return self.output_quantizer(F.relu(self.input_quantizer(x)))


class QuantizedGELU(QuantizationMixin):
    def forward(self, x):
        return self.output_quantizer(F.gelu(self.input_quantizer(x)))


class QuantizedSiLU(QuantizationMixin):
    def forward(self, x):
        return self.output_quantizer(F.silu(self.input_quantizer(x)))
# ======================================================================================
# 5. POOLING AND PASSTHROUGH LAYERS (no quantization logic needed, just passthrough)
# ======================================================================================
class PassthroughWrapper(nn.Module):
    """A simple wrapper for layers that don't need quantization logic."""
    def __init__(self, old_module, **kwargs):
        super().__init__()
        self.old_module = old_module

    def forward(self, x):
        return self.old_module(x)


QuantizedMaxPool2d = PassthroughWrapper
QuantizedAdaptiveAvgPool2d = PassthroughWrapper
QuantizedDropout = PassthroughWrapper
QuantizedIdentity = PassthroughWrapper
# ======================================================================================
# 6. NORMALIZATION LAYERS
# ======================================================================================
#
# !!! IMPORTANT NOTE ON BATCHNORM !!!
#
# A `QuantizedBatchNorm` is INTENTIONALLY OMITTED. During inference (and PTQ),
# BatchNorm layers should be "fused" or "folded" into the preceding Conv/Linear
# layer. You must perform this fusion on the FP32 model BEFORE applying these
# quantization wrappers. Quantizing BatchNorm as a standalone module is a known
# anti-pattern that severely degrades accuracy.
#
# Example of fusion (fuse_modules takes the parent model plus the *names* of the
# modules to fuse, and expects the model to be in eval mode for Conv+BN folding):
# >>> from torch.ao.quantization import fuse_modules
# >>> model_fp32 = ...
# >>> model_fp32.eval()
# >>> fuse_modules(model_fp32, [["conv1", "bn1"]], inplace=True)
#
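# A minimal fuse-then-wrap sketch (assumptions: `model_fp32` is a small CNN whose
# submodules are named "conv1" and "bn1"; the names are placeholders for your own
# model). In eval mode, fusing Conv+BN folds the BatchNorm statistics into the
# convolution and replaces "bn1" with nn.Identity, so only the convolution needs
# a wrapper afterwards:
#
#     import torch.ao.quantization as tq
#
#     model_fp32.eval()                                       # Conv+BN fusion requires eval mode
#     tq.fuse_modules(model_fp32, [["conv1", "bn1"]], inplace=True)
#     model_fp32.conv1 = QuantizedConv2d(old_module=model_fp32.conv1)
#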
class QuantizedLayerNorm(QuantizationMixin):
    def forward(self, x):
        qx = self.input_quantizer(x)
        # Apply the fake-quantized affine parameters through the functional form.
        qw = self.param_quantizers['weight'](self.old_module.weight)
        out = F.layer_norm(qx, self.old_module.normalized_shape, qw,
                           self.old_module.bias, self.old_module.eps)
        return self.output_quantizer(out)
# ======================================================================================
# 7. ELEMENT-WISE OPERATIONS
# ======================================================================================
class QuantizedAdd(QuantizationMixin):
    """Wrapper for element-wise addition, crucial for residual connections."""
    def __init__(self, old_module: nn.Module = None, activation_bits: int = 8, **kwargs):
        super().__init__(old_module if old_module is not None else nn.Identity(),
                         activation_bits=activation_bits)
        # A second input quantizer is needed because addition has two operands.
        act_qmin, act_qmax, act_dtype = get_quant_config(bits=activation_bits, is_symmetric=False)
        self.input_quantizer_2 = FakeQuantize(
            observer=MovingAverageMinMaxObserver, quant_min=act_qmin, quant_max=act_qmax, dtype=act_dtype
        )

    def forward(self, x1, x2):
        # NOTE: For exact integer-only execution, both inputs should share the same
        # scale/zero-point. That requires sharing observers, which adds complexity
        # to the replacement logic (see the illustrative note below).
        qx1 = self.input_quantizer(x1)
        qx2 = self.input_quantizer_2(x2)
        return self.output_quantizer(torch.add(qx1, qx2))
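

# Illustrative note on sharing quantization parameters between the two inputs of
# QuantizedAdd (one possible wiring, not required by the class above): aliasing
# the second quantizer to the first makes both operands observe and quantize
# with the same scale/zero-point.
#
#     add = QuantizedAdd()
#     add.input_quantizer_2 = add.input_quantizer   # both inputs now share qparams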
class QuantizedMul(QuantizationMixin):
    """Wrapper for element-wise multiplication."""
    def __init__(self, old_module: nn.Module = None, activation_bits: int = 8, **kwargs):
        super().__init__(old_module if old_module is not None else nn.Identity(),
                         activation_bits=activation_bits)
        act_qmin, act_qmax, act_dtype = get_quant_config(bits=activation_bits, is_symmetric=False)
        self.input_quantizer_2 = FakeQuantize(
            observer=MovingAverageMinMaxObserver, quant_min=act_qmin, quant_max=act_qmax, dtype=act_dtype
        )

    def forward(self, x1, x2):
        qx1 = self.input_quantizer(x1)
        qx2 = self.input_quantizer_2(x2)
        return self.output_quantizer(torch.mul(qx1, qx2))
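

# ======================================================================================
# 8. EXAMPLE: MODULE REPLACEMENT HELPER (illustrative sketch, not part of the API above)
# ======================================================================================
# The wrappers above are only useful once a model's FP32 layers have been swapped
# for them. The helper below is a minimal sketch of that replacement pass, assuming
# BatchNorm has already been fused away and that the per-layer defaults (8-bit
# activations, 8-bit weights) are acceptable.
_DEFAULT_WRAPPER_MAP = {
    nn.Conv1d: QuantizedConv1d,
    nn.Conv2d: QuantizedConv2d,
    nn.Conv3d: QuantizedConv3d,
    nn.Linear: QuantizedLinear,
    nn.ReLU: QuantizedReLU,
    nn.GELU: QuantizedGELU,
    nn.SiLU: QuantizedSiLU,
    nn.LayerNorm: QuantizedLayerNorm,
    nn.MaxPool2d: QuantizedMaxPool2d,
    nn.AdaptiveAvgPool2d: QuantizedAdaptiveAvgPool2d,
    nn.Dropout: QuantizedDropout,
}


def replace_with_quantized(model: nn.Module, wrapper_map=None) -> nn.Module:
    """Recursively swap supported FP32 layers for their Quantized* wrappers (in place)."""
    wrapper_map = wrapper_map if wrapper_map is not None else _DEFAULT_WRAPPER_MAP
    for name, child in model.named_children():
        wrapper = wrapper_map.get(type(child))
        if wrapper is not None:
            setattr(model, name, wrapper(old_module=child))
        else:
            # Recurse into containers (nn.Sequential, custom blocks, ...).
            replace_with_quantized(child, wrapper_map)
    return model


# Usage (illustrative): model = replace_with_quantized(fused_fp32_model)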