# Quick usage example (kept as a comment; the class definitions are below):
#
# import torch
# import torch.nn as nn
#
# # Step 1: Define a dummy nn.Conv2d and run an FP32 forward pass for reference
# x = torch.randn(1, 3, 32, 32)
# conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
# fp32_output = conv(x)
#
# # Step 2: Wrap it with QuantizedConv2d (8-bit activations, 8-bit weights)
# qconv = QuantizedConv2d(old_module=conv, activation_bits=8, weight_bits=8)
#
# # Step 3: Run the same dummy input through the fake-quantized wrapper
# output = qconv(x)
#
# # Step 4: Check the output; both shapes should be torch.Size([1, 8, 32, 32])
# print("Quantized output shape:", output.shape)
# print("FP32 output shape:", fp32_output.shape)


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import (
    FakeQuantize,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
)

# ======================================================================================
# 1. Helper Function for Quantization Configuration
# ======================================================================================

def get_quant_config(bits: int, is_symmetric: bool):
    """Returns quant_min, quant_max, and torch.dtype for a given bitwidth."""
    if is_symmetric:
        # For symmetric quantization (typically for weights)
        if bits == 8:
            return -128, 127, torch.qint8
        elif bits == 4:
            return -8, 7, torch.qint8
        else:
            raise ValueError(f"Unsupported symmetric bitwidth: {bits}")
    else:
        # For asymmetric quantization (typically for activations)
        if bits == 8:
            return 0, 255, torch.quint8
        elif bits == 4:
            return 0, 15, torch.quint8
        else:
            raise ValueError(f"Unsupported asymmetric bitwidth: {bits}")
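# For reference, the mapping implemented by the branches above:
#   get_quant_config(8, is_symmetric=True)  -> (-128, 127, torch.qint8)
#   get_quant_config(4, is_symmetric=True)  -> (-8,   7,   torch.qint8)
#   get_quant_config(8, is_symmetric=False) -> (0,    255, torch.quint8)
#   get_quant_config(4, is_symmetric=False) -> (0,    15,  torch.quint8)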

# ======================================================================================
# 2. Base Class for All Quantized Modules
# ======================================================================================

class QuantizationMixin(nn.Module):
    """
    Base mixin for custom quantized modules.
    Handles the creation of FakeQuantize modules for inputs, outputs, and parameters.
    """
    def __init__(self, old_module: nn.Module, activation_bits: int = 8, weight_bits: int = 8):
        super().__init__()
        self.old_module = old_module

        # Activation quantizers (asymmetric, per-tensor)
        act_qmin, act_qmax, act_dtype = get_quant_config(bits=activation_bits, is_symmetric=False)
        self.input_quantizer = FakeQuantize(
            observer=MovingAverageMinMaxObserver, quant_min=act_qmin, quant_max=act_qmax,
            dtype=act_dtype, qscheme=torch.per_tensor_affine, reduce_range=False
        )
        self.output_quantizer = FakeQuantize(
            observer=MovingAverageMinMaxObserver, quant_min=act_qmin, quant_max=act_qmax,
            dtype=act_dtype, qscheme=torch.per_tensor_affine, reduce_range=False
        )

        # Weight quantizers (symmetric)
        self.param_quantizers = nn.ModuleDict()
        if not list(self.old_module.named_parameters(recurse=False)):
            return

        weight_qmin, weight_qmax, weight_dtype = get_quant_config(bits=weight_bits, is_symmetric=True)
        # Conv weights are quantized per output channel; everything else per tensor.
        # Per-channel qschemes require a per-channel observer and a channel axis.
        per_channel = isinstance(self.old_module, (nn.Conv1d, nn.Conv2d, nn.Conv3d))

        for name, _ in self.old_module.named_parameters(recurse=False):
            if 'weight' in name:
                if per_channel:
                    self.param_quantizers[name] = FakeQuantize(
                        observer=MovingAveragePerChannelMinMaxObserver,
                        quant_min=weight_qmin, quant_max=weight_qmax, dtype=weight_dtype,
                        qscheme=torch.per_channel_symmetric, reduce_range=False, ch_axis=0
                    )
                else:
                    self.param_quantizers[name] = FakeQuantize(
                        observer=MovingAverageMinMaxObserver,
                        quant_min=weight_qmin, quant_max=weight_qmax, dtype=weight_dtype,
                        qscheme=torch.per_tensor_symmetric, reduce_range=False
                    )

    def forward(self, *args, **kwargs):
        raise NotImplementedError("Forward pass must be implemented by subclasses.")

# ======================================================================================
# 3. CONVOLUTIONAL AND LINEAR LAYERS
# ======================================================================================

class QuantizedConv1d(QuantizationMixin):
    def forward(self, x):
        qx = self.input_quantizer(x)
        qw = self.param_quantizers['weight'](self.old_module.weight)
        out = F.conv1d(qx, qw, self.old_module.bias, self.old_module.stride,
                       self.old_module.padding, self.old_module.dilation, self.old_module.groups)
        return self.output_quantizer(out)

class QuantizedConv2d(QuantizationMixin):
    def forward(self, x):
        qx = self.input_quantizer(x)
        qw = self.param_quantizers['weight'](self.old_module.weight)
        out = F.conv2d(qx, qw, self.old_module.bias, self.old_module.stride,
                       self.old_module.padding, self.old_module.dilation, self.old_module.groups)
        return self.output_quantizer(out)

class QuantizedConv3d(QuantizationMixin):
    def forward(self, x):
        qx = self.input_quantizer(x)
        qw = self.param_quantizers['weight'](self.old_module.weight)
        out = F.conv3d(qx, qw, self.old_module.bias, self.old_module.stride,
                       self.old_module.padding, self.old_module.dilation, self.old_module.groups)
        return self.output_quantizer(out)

class QuantizedLinear(QuantizationMixin):
    def forward(self, x):
        qx = self.input_quantizer(x)
        qw = self.param_quantizers['weight'](self.old_module.weight)
        out = F.linear(qx, qw, self.old_module.bias)
        return self.output_quantizer(out)
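# A minimal, illustrative sketch (not called anywhere in this module) of the intended
# calibration flow for a single wrapped layer: run representative data through the
# wrapper so the FakeQuantize observers collect ranges, then freeze the observers with
# torch.ao.quantization.disable_observer before evaluation. The layer sizes and the
# helper name are assumptions for illustration.
def _example_calibrate_linear():
    from torch.ao.quantization import disable_observer

    linear = nn.Linear(16, 4)
    qlinear = QuantizedLinear(old_module=linear, activation_bits=8, weight_bits=8)

    # "Calibration": forward passes update the observers' running min/max statistics.
    for _ in range(8):
        qlinear(torch.randn(32, 16))

    # Freeze the collected ranges; fake quantization stays active for evaluation.
    qlinear.apply(disable_observer)
    return qlinear(torch.randn(2, 16))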

# ======================================================================================
# 4. ACTIVATION FUNCTIONS
# ======================================================================================

class QuantizedReLU(QuantizationMixin):
    def forward(self, x):
        # Note: In a fused block (Conv-BN-ReLU), the input 'x' is already quantized.
        # Calling input_quantizer again is idempotent and harmless.
        return self.output_quantizer(F.relu(self.input_quantizer(x)))

class QuantizedGELU(QuantizationMixin):
    def forward(self, x):
        return self.output_quantizer(F.gelu(self.input_quantizer(x)))

class QuantizedSiLU(QuantizationMixin):
    def forward(self, x):
        return self.output_quantizer(F.silu(self.input_quantizer(x)))

# ======================================================================================
# 5. POOLING AND PASSTHROUGH LAYERS (no quantization needed, just passthrough)
# ======================================================================================

class PassthroughWrapper(nn.Module):
    """A simple wrapper for layers that don't need quantization logic."""
    def __init__(self, old_module, **kwargs):
        super().__init__()
        self.old_module = old_module

    def forward(self, x):
        return self.old_module(x)

QuantizedMaxPool2d = PassthroughWrapper
QuantizedAdaptiveAvgPool2d = PassthroughWrapper
QuantizedDropout = PassthroughWrapper
QuantizedIdentity = PassthroughWrapper

# ======================================================================================
# 6. NORMALIZATION LAYERS
# ======================================================================================

#
# !!! IMPORTANT NOTE ON BATCHNORM !!!
#
# A `QuantizedBatchNorm` is INTENTIONALLY OMITTED. During inference (and PTQ),
# BatchNorm layers should be "fused" or "folded" into the preceding Conv/Linear
# layer. You must perform this fusion on the FP32 model BEFORE applying these
# quantization wrappers. Quantizing BatchNorm as a standalone module is a known
# anti-pattern that severely degrades accuracy.
#
# Example of fusion (fuse_modules takes the parent model plus a list of child-module
# names to fuse); see the runnable sketch below:
# >>> from torch.ao.quantization import fuse_modules
# >>> model_fp32 = ...
# >>> fuse_modules(model_fp32, [['conv1', 'bn1']], inplace=True)
#
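# A minimal, hypothetical sketch of the fusion step described above (not called
# anywhere in this module). The tiny nn.Sequential and its child names ("0", "1", "2")
# are illustrative assumptions; adapt the name lists to your own architecture.
def _example_fuse_conv_bn_relu():
    from torch.ao.quantization import fuse_modules

    model_fp32 = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),  # child "0"
        nn.BatchNorm2d(8),                          # child "1"
        nn.ReLU(),                                  # child "2"
    )
    model_fp32.eval()  # Conv-BN folding via fuse_modules is defined for eval mode
    # Folds the BatchNorm into the Conv weights/bias and fuses the ReLU, leaving
    # nn.Identity placeholders where the absorbed modules used to be.
    fused = fuse_modules(model_fp32, [["0", "1", "2"]])
    return fused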

class QuantizedLayerNorm(QuantizationMixin):
    def forward(self, x):
        qx = self.input_quantizer(x)
        # LayerNorm is unique; it uses the functional form with its (fake-quantized) parameters.
        # Guard against elementwise_affine=False, in which case there is no weight to quantize.
        qw = self.param_quantizers['weight'](self.old_module.weight) if 'weight' in self.param_quantizers else None
        out = F.layer_norm(qx, self.old_module.normalized_shape, qw, self.old_module.bias, self.old_module.eps)
        return self.output_quantizer(out)

# ======================================================================================
# 7. ELEMENT-WISE OPERATIONS
# ======================================================================================

class QuantizedAdd(QuantizationMixin):
    """Wrapper for element-wise addition, crucial for residual connections."""
    def __init__(self, old_module=None, activation_bits=8, **kwargs):
        super().__init__(old_module if old_module is not None else nn.Identity(), activation_bits=activation_bits)
        # A second input quantizer is needed because the two operands may have different ranges.
        act_qmin, act_qmax, act_dtype = get_quant_config(bits=activation_bits, is_symmetric=False)
        self.input_quantizer_2 = FakeQuantize(
            observer=MovingAverageMinMaxObserver, quant_min=act_qmin, quant_max=act_qmax,
            dtype=act_dtype, qscheme=torch.per_tensor_affine, reduce_range=False
        )

    def forward(self, x1, x2):
        # NOTE: For best accuracy, both inputs should share the same scale/zero-point.
        # This requires sharing observers, which adds complexity to the replacement
        # logic (see the sketch below).
        qx1 = self.input_quantizer(x1)
        qx2 = self.input_quantizer_2(x2)
        return self.output_quantizer(torch.add(qx1, qx2))
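# A minimal sketch of the observer sharing mentioned in the NOTE above (an assumption
# about how it could be wired, not part of the replacement logic itself): pointing both
# input quantizers of a QuantizedAdd at the same FakeQuantize module forces the two
# operands onto a single scale/zero-point.
def _example_share_add_input_quantizers():
    qadd = QuantizedAdd(activation_bits=8)
    # Reuse the first input quantizer for the second operand; both branches of the
    # residual connection are now fake-quantized with identical parameters.
    qadd.input_quantizer_2 = qadd.input_quantizer
    a, b = torch.randn(1, 8, 16, 16), torch.randn(1, 8, 16, 16)
    return qadd(a, b)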

class QuantizedMul(QuantizationMixin):
    """Wrapper for element-wise multiplication."""
    def __init__(self, old_module=None, activation_bits=8, **kwargs):
        super().__init__(old_module if old_module is not None else nn.Identity(), activation_bits=activation_bits)
        act_qmin, act_qmax, act_dtype = get_quant_config(bits=activation_bits, is_symmetric=False)
        self.input_quantizer_2 = FakeQuantize(
            observer=MovingAverageMinMaxObserver, quant_min=act_qmin, quant_max=act_qmax,
            dtype=act_dtype, qscheme=torch.per_tensor_affine, reduce_range=False
        )

    def forward(self, x1, x2):
        qx1 = self.input_quantizer(x1)
        qx2 = self.input_quantizer_2(x2)
        return self.output_quantizer(torch.mul(qx1, qx2))
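# ======================================================================================
# 8. EXAMPLE: MODEL SURGERY AND CALIBRATION (illustrative sketch)
# ======================================================================================
# The helper and demo below sketch how these wrappers might be applied to a whole model:
# every supported leaf module is swapped for its quantized counterpart, a few batches
# are run to calibrate the observers, and the observers are then frozen. The mapping,
# the helper name `replace_with_quantized`, and the toy model are assumptions for
# illustration, not a definitive integration recipe (in particular, functional residual
# adds still have to be rewired to QuantizedAdd by hand).

_QUANTIZED_MODULE_MAP = {
    nn.Conv1d: QuantizedConv1d,
    nn.Conv2d: QuantizedConv2d,
    nn.Conv3d: QuantizedConv3d,
    nn.Linear: QuantizedLinear,
    nn.ReLU: QuantizedReLU,
    nn.GELU: QuantizedGELU,
    nn.SiLU: QuantizedSiLU,
    nn.LayerNorm: QuantizedLayerNorm,
    nn.MaxPool2d: QuantizedMaxPool2d,
    nn.AdaptiveAvgPool2d: QuantizedAdaptiveAvgPool2d,
    nn.Dropout: QuantizedDropout,
    nn.Identity: QuantizedIdentity,
}

def replace_with_quantized(model: nn.Module, activation_bits: int = 8, weight_bits: int = 8) -> nn.Module:
    """Recursively replace supported leaf modules with their quantized wrappers (in place)."""
    for name, child in model.named_children():
        wrapper_cls = _QUANTIZED_MODULE_MAP.get(type(child))
        if wrapper_cls is not None:
            setattr(model, name, wrapper_cls(old_module=child,
                                             activation_bits=activation_bits,
                                             weight_bits=weight_bits))
        else:
            replace_with_quantized(child, activation_bits=activation_bits, weight_bits=weight_bits)
    return model

if __name__ == "__main__":
    from torch.ao.quantization import disable_observer

    # Toy FP32 model (BatchNorm omitted here; see the fusion note in section 6).
    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    )

    qmodel = replace_with_quantized(model)

    # Calibration: a few forward passes let the observers collect activation ranges.
    for _ in range(4):
        qmodel(torch.randn(2, 3, 32, 32))

    # Freeze the observers; fake quantization remains active for evaluation.
    qmodel.apply(disable_observer)
    print("Quantized model output shape:", qmodel(torch.randn(1, 3, 32, 32)).shape)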