import json
import os
from typing import Dict

import torch
import torch.fx as fx
import torch.nn as nn
from core.frameworks.pytorch.quantization.sdk.graph_utils import ModelGraphUtils
from core.frameworks.pytorch.quantization.sdk.model_preparer import (
    _get_module_for_dotted_name,
    prepare_model,
)
from core.frameworks.pytorch.quantization.sdk.modules.fake_quantize import (
    FakeQuantize as CustomFakeQuantize,
)
from core.frameworks.pytorch.quantization.sdk.modules.module_replacement_registry import (
    ModuleReplacements,
)
from core.frameworks.pytorch.quantization.sdk.modules.quantized_modules import (
    QuantizationMixin,
)
from core.frameworks.pytorch.quantization.sdk.quantization_utils import (
    build_clean_onnx_path,
    build_onnx_path_from_torch_name,
    get_bitwidth_from_string_dtype,
    normalize_class_name,
    print_quant_summary,
)
from core.frameworks.pytorch.utils.utils import is_leaf_module
from logger.logger_registry import get_logger
from utils.config_utils import (
    get_model_configs,
    get_task_configs,
)
from utils.helpers import set_nested_attr

from .modules.function_modules import REPLACEMENTS
from .quantsim_config.quantsim_parser import QuantSimConfigParser

logger = get_logger("main")


class QuantizerEnginer:
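    """
    Drives the quantization workflow for a PyTorch model: prepares the model
    (Conv-BN fusion, FX tracing, wrapping quantizable modules), applies the
    input/output quantizer policy from the quantsim config, converts the
    trained/calibrated model back to plain modules, exports to ONNX, and dumps
    the quantization encodings to JSON.

    Minimal usage sketch (the config layout and the calibration step are
    assumptions based on how this class reads its config, not a fixed API):

        engine = QuantizerEnginer(model, configs, torch.device("cpu"))
        prepared = engine.prepare()
        # ... run calibration / fine-tuning on `prepared` ...
        engine.convert(prepared)
        engine.export_model(prepared, engine.task_type)
        engine.generate_embeddings(attribute=None)
    """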
    def __init__(
        self,
        model: torch.nn.Module,
        configs,
        device: torch.device,
    ):
        self._model: torch.nn.Module = model
        self.config = configs
        self.logger = get_logger("main")
        self.device = device
        self.task_type = (
            "qft" if self.config["qft"] else "ptq" if self.config["ptq"] else "pruning"
        )

        self.quantsim_config = QuantSimConfigParser()

    def update_model(self, model: torch.nn.Module):
        self._model = model

    def fuse_all_conv_bn(self, module):
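        """
        Fuse Conv2d layers with the BatchNorm2d that immediately follows them.

        Pairs are detected by scanning `named_modules()` in order, so this
        assumes consecutive registration order matches the forward order.
        """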
        fusion_list = []
        module_list = list(module.named_modules())

        for (name1, mod1), (name2, mod2) in zip(module_list, module_list[1:]):
            if isinstance(mod1, torch.nn.Conv2d) and isinstance(
                mod2, torch.nn.BatchNorm2d
            ):
                fusion_list.append([name1, name2])

        if fusion_list:
            torch.ao.quantization.fuse_modules(module, fusion_list, inplace=True)

    def _check_super_groups_module(self, layer_name: str) -> bool:
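        """
        Decide whether a layer's output should be quantized given the
        super-group patterns in the quantsim config: if the layer and its next
        non-identity neighbour form the start of a configured super group, the
        intermediate output quantizer is skipped (returns False).
        """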
        super_groups = self.quantsim_config.get_super_groups_ops()
        next_node_name = self.graphutils.get_next_non_identity_module(layer_name)
        if not next_node_name:
            return self.quantsim_config.is_default_ops_output_quantized()
        layer_name = self.graphutils.get_normalized_layer_name(layer_name)
        next_node_name = self.graphutils.get_normalized_layer_name(next_node_name)
        if next_node_name:
            next_node_name_lower = next_node_name.lower()
            node_name_lower = layer_name.lower()
            for super_group in super_groups:
                super_group_first = super_group[0].lower()
                super_group_second = super_group[1].lower()
                if (
                    node_name_lower in super_group_first
                    and next_node_name_lower in super_group_second
                ):
                    return False
                elif next_node_name_lower in super_group_second:
                    return True
        return self.quantsim_config.is_default_ops_output_quantized()

    def apply_io_quantizer_flags_to_graph(
        self, graph: torch.fx.GraphModule
    ) -> dict[str, tuple[bool, bool]]:
        """
        Iterate over all nodes in the graph and determine IO quantizer flags.

        Returns:
            A dictionary mapping node names to (add_input_quantizer, add_output_quantizer) flags.
        """
        io_flags = {}
        for node in graph.graph.nodes:
            if node.op == "call_module":
                flags = self.__get_io_quantizer_flags(graph, node, None)
                io_flags[node.name] = flags
                self.logger.info(
                    "[IO Quantizer Flags] %s: Input=%s, Output=%s",
                    node.name,
                    flags[0],
                    flags[1],
                )
                module: QuantizationMixin = _get_module_for_dotted_name(
                    graph, node.target
                )
                if not isinstance(module, QuantizationMixin):
                    continue
                if flags[0]:
                    module.enable_input_quantization()
                else:
                    module.disable_input_quantization()
                if flags[1]:
                    module.enable_output_quantization()
                else:
                    module.disable_output_quantization()

        return io_flags

    def _check_super_groups_node(
        self,
        layer_name: str,
        node: torch.fx.Node,
        next_node: torch.fx.Node,
        next_node_name: str,
    ) -> bool:
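        """
        Node-level variant of the super-group check: returns False (skip the
        output quantizer) when `layer_name` and `next_node_name` match the
        first and second ops of a configured super group, True otherwise.
        """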
        super_groups = self.quantsim_config.get_super_groups_ops()
        for super_group in super_groups:
            if (
                layer_name in super_group[0].lower()
                and next_node
                and next_node_name in super_group[1].lower()
            ):
                self.logger.debug(
                    "Super-group match: %s -> %s", layer_name, next_node_name
                )
                return False
            elif layer_name in super_group[1].lower():
                return True
        return True

    def __is_quantizable_module(self, module: torch.nn.Module):
        return ModuleReplacements.get_replacement(module) is not None

    @classmethod
    def __is_quantized_module(cls, module: torch.nn.Module):
        return isinstance(module, QuantizationMixin)

    def _normalized_op_name(self, module: torch.nn.Module) -> str:
        """Map a module to the normalized op name used in your config."""
        return self.graphutils.get_normalized_layer_name(module)

    def _match_pattern_from(
        self,
        graph: torch.fx.GraphModule,
        start_node: torch.fx.Node,
        pattern_ops: list[str],
    ):
        """
        Try to match a supergroup pattern starting at start_node.
        Returns list of nodes if matched, else None.
        """
        if start_node.op != "call_module":
            return None

        try:
            first_module = _get_module_for_dotted_name(graph, start_node.target)
        except Exception:
            return None

        first_name = self._normalized_op_name(first_module)
        if first_name != pattern_ops[0]:
            return None

        matched_nodes = [start_node]
        curr = start_node
        for expected in pattern_ops[1:]:
            next_node = self.graphutils.get_next_non_identity_node(curr)
            if not next_node or next_node.op != "call_module":
                return None
            try:
                next_module = _get_module_for_dotted_name(graph, next_node.target)
            except Exception:
                return None
            next_name = self._normalized_op_name(next_module)
            if next_name != expected:
                return None
            matched_nodes.append(next_node)
            curr = next_node

        return matched_nodes

    def _collect_supergroup_patterns(self):
        """
        Read patterns from your quantsim config.
        Expected structure: [["Conv", "BatchNorm", "Relu"], ["MatMul", "Add"]]
        """
        cfg = self.quantsim_config.get_model_quantization_config()
        patterns = cfg.get("supergroups", [])
        cleaned = []
        for pat in patterns:
            if isinstance(pat, (list, tuple)) and len(pat) >= 2:
                cleaned.append([str(x) for x in pat])
        return cleaned

    def _apply_super_groups_config(self, graph: torch.fx.GraphModule):
        """
        Find and apply super-group quantization configuration.
        Works for groups of any length >= 2. Prevents overlaps.
        """
        patterns = self._collect_supergroup_patterns()
        self._supergroup_members = {}
        claimed = set()
        group_id = 0

        for node in graph.graph.nodes:
            if node.op != "call_module" or node in claimed:
                continue

            for pat in patterns:
                match = self._match_pattern_from(graph, node, pat)
                if not match:
                    continue
                if any(n in claimed for n in match):
                    continue

                size = len(match)
                for idx, member in enumerate(match):
                    self._supergroup_members[member] = (group_id, idx, size)
                    claimed.add(member)

                # Gather actual modules
                modules = []
                for mnode in match:
                    m = _get_module_for_dotted_name(graph, mnode.target)
                    modules.append(m)

                # Apply quantizer sharing
                self._apply_super_group_action_general(modules)

                if getattr(self, "verbose", False):
                    names = [self._normalized_op_name(m) for m in modules]
                    print(f"[SuperGroup] id={group_id} matched: {' -> '.join(names)}")

                group_id += 1
                break

    def _belongs_to_super_group(self, node: torch.fx.Node) -> bool:
        return hasattr(self, "_supergroup_members") and node in self._supergroup_members

    def _supergroup_position(self, node: torch.fx.Node):
        """Return (group_id, idx, size) if node is in a supergroup, else None."""
        if self._belongs_to_super_group(node):
            return self._supergroup_members[node]
        return None

    def _apply_super_group_action_general(self, modules: list):
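        """
        Share a single output quantizer across a matched super group: every
        member reuses the last module's `output_quantizer`, so intermediate
        activations inside the group are not independently quantized.
        """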
        n = len(modules)
        if n < 2:
            return

        last_module = modules[-1]
        if not hasattr(last_module, "output_quantizer"):
            return

        shared_output_q = last_module.output_quantizer

        for m in modules[:-1]:
            if hasattr(m, "output_quantizer"):
                m.output_quantizer = shared_output_q

    def __get_io_quantizer_flags(
        self, graph, node: torch.fx.Node, layer_name: str | None
    ) -> tuple[bool, bool]:
        """
        Decide whether to add input/output quantizers for a node,
        respecting model IO policy and per-layer overrides.
        """
        layer_name = layer_name or self.graphutils.get_normalized_layer_name(node)
        model_quant_config = self.quantsim_config.get_model_quantization_config()

        add_input_quantizer = self.quantsim_config.is_default_ops_input_quantized()
        add_output_quantizer = self.quantsim_config.is_default_ops_output_quantized()

        if self.graphutils._is_first(node, layer_name):
            add_input_quantizer = model_quant_config.get("input_quantized", False)
        if self.graphutils._is_last(node, layer_name):
            add_output_quantizer = model_quant_config.get("output_quantized", False)

        # No need for explicit supergroup override:
        # all sharing is already handled in _apply_super_group_action_general.

        add_input_quantizer = self._apply_layer_override(
            layer_name, add_input_quantizer, is_input=True
        )
        add_output_quantizer = self._apply_layer_override(
            layer_name, add_output_quantizer, is_input=False
        )
        return add_input_quantizer, add_output_quantizer

    def _apply_layer_override(
        self, layer_name: str, current_value: bool, is_input: bool
    ) -> bool:
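        """
        Apply per-layer overrides from the quantsim config: explicit
        skip-lists force the quantizer off, explicit add-lists force it on,
        otherwise the incoming default value is kept.
        """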
        if is_input:
            if (
                layer_name
                in self.quantsim_config.get_layers_to_skip_from_input_quantizers()
            ):
                return False
            if layer_name in self.quantsim_config.get_layers_to_add_input_quantizers():
                return True
        else:
            if (
                layer_name
                in self.quantsim_config.get_layers_to_skip_from_output_quantizers()
            ):
                return False
            if layer_name in self.quantsim_config.get_layers_to_add_output_quantizers():
                return True
        return current_value

    def _add_quantization_wrappers(self, module, prefix=""):
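        """
        Recursively walk the module tree and replace quantizable leaf modules
        with their registered quantized wrappers; modules that are already
        quantized are left untouched.
        """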

        if self.__is_quantized_module(module):
            return
        for module_name, module_ref in module.named_children():
            full_name = f"{prefix}.{module_name}" if prefix else module_name
            self.logger.info("nn.Module found : %s", module_ref)

            if self.__is_quantizable_module(module_ref) and is_leaf_module(module_ref):
                quantized_module = self._create_quantizer_module(module_ref, full_name)
                if not quantized_module:
                    self.logger.info(
                        "No quantized replacement registered for %s; skipping",
                        full_name,
                    )
                    continue
                setattr(module, module_name, quantized_module)
            else:
                self._add_quantization_wrappers(module_ref, prefix=full_name)

    def _create_quantizer_module(
        self, module_to_quantize: torch.nn.Module, module_name: str
    ):
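        """
        Build a quantized wrapper for `module_to_quantize` using the global
        PTQ settings (dtype, observers, per-channel/per-tensor granularity,
        symmetry) from the task config. Returns None if no replacement is
        registered for the module type.
        """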

        param_per_channel = get_task_configs(
            self.config, "ptq", "parameter_per_channel", False
        )
        act_per_channel = get_task_configs(
            self.config, "ptq", "activation_per_channel", False
        )
        param_per_tensor = get_task_configs(
            self.config, "ptq", "parameter_per_tensor", True
        )
        act_per_tensor = get_task_configs(
            self.config, "ptq", "activation_per_tensor", True
        )
        param_is_symmetric = get_task_configs(
            self.config, "ptq", "parameter_is_symmetric", True
        )
        act_is_symmetric = get_task_configs(
            self.config, "ptq", "activation_is_symmetric", False
        )

        global_activation_dtype = get_task_configs(
            self.config, "ptq", "activation_dtype", "int4"
        )
        global_param_dtype = get_task_configs(
            self.config, "ptq", "parameter_dtype", "int4"
        )
        global_activation_observer = get_task_configs(
            self.config, "ptq", "activation_observer"
        )
        global_weight_observer = get_task_configs(self.config, "ptq", "weight_observer")

        quantizer = ModuleReplacements.get_replacement(module_to_quantize)

        if not quantizer:
            self.logger.info(
                "No quantized replacement registered for %s", type(module_to_quantize)
            )
            return None

        # Attach the global quantization settings (dtypes, observers,
        # granularity, symmetry) to the module being wrapped.
        module_to_quantize.activation_dtype = global_activation_dtype
        module_to_quantize.parameter_dtype = global_param_dtype
        module_to_quantize.parameter_observer = global_weight_observer
        module_to_quantize.activation_observer = global_activation_observer
        module_to_quantize.param_per_channel = param_per_channel
        module_to_quantize.act_per_channel = act_per_channel
        module_to_quantize.param_per_tensor = param_per_tensor
        module_to_quantize.act_per_tensor = act_per_tensor
        module_to_quantize.param_is_symmetric = param_is_symmetric
        module_to_quantize.act_is_symmetric = act_is_symmetric

        self.logger.info(
            f"Replacing {type(module_to_quantize)} with {quantizer.__name__} for quantization"
        )
        quantized_module = quantizer(
            _module_to_wrap=module_to_quantize,
            # add_input_quantizers=add_input_quantizers,
            # add_output_quantizers=add_output_quantizers,
        )
        return quantized_module

    def prepare(self):
        """
        Prepare the model for quantization: fuse Conv-BN pairs, trace the
        model with torch.fx, replace quantizable leaf modules with quantized
        wrappers, and apply the input/output quantizer flags from the
        quantsim config.

        Returns:
            The prepared, quantization-ready model.
        """
        self.logger.info("=" * 60)
        self.logger.info("Preparing model for quantization (task: %s)", self.task_type)
        self.logger.info("=" * 60)
        self._model.eval()
        self.fuse_all_conv_bn(self._model)

        # try:
        self.traced_graph = fx.symbolic_trace(self._model)
        self.graphutils = ModelGraphUtils(self._model, self.traced_graph)
        self._model = prepare_model(self._model)
        self.graphutils.update_model(self._model)
        self._add_quantization_wrappers(self._model)
        self.graphutils.update_graph(self._model)
        self.apply_io_quantizer_flags_to_graph(self._model)
        self.logger.info(self._model)

        # except Exception as e:
        #     print(e)
        #     self.graphutils = ModelGraphUtils(self._model, None)
        #     self.logger.info("Model is not graph tracable. Cannot replace math ops.")
        #     self._add_quantization_wrappers(self._model)

        self.logger.info("Model after preparing")
        self.logger.info(self._model)
        print_quant_summary(self._model)
        return self._model

    def convert(self, model) -> None:
        """
        Convert a prepared model back to plain nn.Modules: for every
        QuantizationMixin wrapper, register the computed scale/zero-point
        buffers on the wrapped module and swap the wrapper out.
        This method is called after the training and preparation steps.
        """

        for name, child in list(model.named_children()):
            if isinstance(child, QuantizationMixin):
                self.logger.info(f"Replacing Quantized module: {name}")
                module: QuantizationMixin = child
                scale_fp = module.get_scale_fp()
                wrapped_module: nn.Module = module._module_to_wrap
                for k, v in scale_fp.items():
                    # register_buffer also sets the attribute on the module,
                    # so no additional setattr is needed.
                    wrapped_module.register_buffer(f"{k}_scale", v["scale"])
                    wrapped_module.register_buffer(f"{k}_zero_point", v["zero_point"])

                setattr(model, name, wrapped_module)
                continue

            self.convert(child)

    def export_model(self, model: torch.nn.Module, task_type: str) -> None:
        """
        Export the quantized model to a format suitable for deployment.
        """
        self.model_name = get_model_configs(self.config)["name"]
        export_format = self.config["export"]["format"]
        output_dir = os.path.join(self.config["export"]["output_dir"], self.model_name)
        if export_format == "onnx":
            export_model = model.apply(torch.ao.quantization.disable_observer)
            export_model.cpu()
            os.makedirs(output_dir, exist_ok=True)

            param_dtype = get_task_configs(
                self.config, task_type, "parameter_dtype", "int4"
            )
            opset = get_task_configs(self.config, "export", "opset", 13)
            attribute = get_task_configs(self.config, "export", "attribute", None)
            if attribute is not None:
                if not hasattr(export_model, attribute):
                    raise ValueError(f"Model has no attribute '{attribute}' for export")
                export_model = getattr(export_model, attribute)

            onnx_file = (
                f"{self.model_name}_quantized_model_{task_type}_{param_dtype}.onnx"
            )

            output_path = os.path.join(output_dir, onnx_file)

            self.logger.info("=" * 60)
            self.logger.info("Exporting the quantized model")
            self.logger.info("=" * 60)
            self.logger.info(f"Model Name      : {self.model_name}")
            self.logger.info(f"Export Format   : {export_format}")
            self.logger.info(f"Output Path     : {output_path}")
            self.logger.info("=" * 60)
            try:
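                # NOTE: self._input_shape is assumed to be provided elsewhere
                # (e.g., by the calling pipeline) before export is invoked; it
                # is not set anywhere in this class.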
                torch.onnx.export(
                    export_model,
                    torch.randn(self._input_shape),  # type: ignore
                    output_path,
                    export_params=True,
                    opset_version=opset,
                    do_constant_folding=True,
                    input_names=["input_image"],
                    output_names=["output"],
                )
                self.logger.info("Model export completed successfully.")
            except Exception as e:
                self.logger.error("Model export failed.")
                self.logger.exception(e)

    def __extract_encoding(self, module: CustomFakeQuantize) -> Dict:
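        """
        Convert a FakeQuantize module's state into a serializable encoding
        dict. For symmetric schemes the min/max range is derived from the
        scale and the width of the quant range; for affine schemes it is
        derived from scale and zero point. Per-channel quantizers produce one
        entry per channel.
        """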
        scale = module.scale
        zero_point = module.zero_point
        quant_min = module.quant_min
        quant_max = module.quant_max
        qscheme = module.qscheme
        dtype = str(module.dtype)

        bitwidth = get_bitwidth_from_string_dtype(dtype)

        is_symmetric = qscheme in [
            torch.per_tensor_symmetric,
            torch.per_channel_symmetric,
        ]

        if is_symmetric:
            encoding_min = -scale * ((quant_max - quant_min) / 2)
            encoding_max = scale * ((quant_max - quant_min) / 2)
        else:
            encoding_min = scale * (quant_min - zero_point)
            encoding_max = scale * (quant_max - zero_point)

        base_info = {
            "bitwidth": bitwidth,
            "quant_min": quant_min,
            "quant_max": quant_max,
            "qscheme": str(qscheme),
            "dtype": dtype,
            "is_symmetric": is_symmetric,
        }

        if scale.numel() == 1:
            return {
                **base_info,
                "encodings": [
                    {
                        "scale": scale.item(),
                        "offset": zero_point.item(),
                        "min": encoding_min.item(),
                        "max": encoding_max.item(),
                    }
                ],
            }

        else:
            encodings = []
            for i in range(scale.numel()):
                encodings.append(
                    {
                        "scale": scale[i].item(),
                        "offset": zero_point[i].item(),
                        "min": encoding_min[i].item(),
                        "max": encoding_max[i].item(),
                    }
                )
        return {**base_info, "encodings": encodings}

    def __get_quant_min_max(self, dtype):
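        """
        Map a dtype string to (quant_min, quant_max, torch_dtype). Sub-byte
        dtypes (int4/uint4) are simulated on top of the 8-bit torch dtypes.
        """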
        dtype = dtype.lower()
        if dtype == "int4":
            return -8, 7, torch.qint8
        elif dtype == "uint4":
            return 0, 15, torch.quint8
        elif dtype == "int8":
            return -128, 127, torch.qint8
        elif dtype == "uint8":
            return 0, 255, torch.quint8
        elif dtype == "int16":
            return -32768, 32767, torch.qint32
        # uint16 (not enabled here) would map to: 0, 65535, torch.qint32
        else:
            raise ValueError(f"Unsupported dtype: {dtype}")

    def __get_encodings_from_model(self, model) -> Dict[str, Dict]:
        """
        Extracts quantization parameters (scale, zero_point, min, max, etc.)
        from a model prepared with this engine's `prepare` method.

        Returns a hierarchical dictionary of encodings per FakeQuantize module.
        """
        encoding_dict = {}

        for name, child in model.named_modules():
            if not name:
                continue
            if "quant" in name:
                continue
            onnx_path = build_onnx_path_from_torch_name(name)
            cls_name = normalize_class_name(child.__class__.__name__)
            output_name = build_clean_onnx_path(f"{onnx_path}/{cls_name}")
            output_name = output_name.replace("Quantized", "")
            if "module_" in output_name:
                output_name = output_name.replace("module_", "")
            if isinstance(child, QuantizationMixin):
                data = {}
                for sub_name, sub_module in child.named_modules():
                    if isinstance(sub_module, (CustomFakeQuantize)):
                        encoding_info = self.__extract_encoding(sub_module)
                        data[f"{sub_name.replace('_quantizers','')}"] = encoding_info
                encoding_dict[output_name] = data

        return encoding_dict

    def generate_embeddings(self, attribute):
        """
        Collects quantization encodings from every QuantizationMixin module in
        the model and writes them to a JSON file in the export output directory.
        """
        output_dir = os.path.join(
            self.config["export"]["output_dir"],
            get_model_configs(self.config, "name"),
        )
        # if attribute is not None:
        #     if not hasattr(model, attribute):
        #         raise ValueError(f"Model has no attribute '{attribute}' for export")
        #     model = getattr(model, attribute)

        all_encodings = self.__get_encodings_from_model(
            self._model
            # if  isinstance(attribute, str) and not hasattr(self._model, attribute)
            # else  isinstance(attribute, str) and getattr(self._model, attribute)
        )

        # Reuse the task type resolved in __init__ for consistency.
        task_type = self.task_type
        output_path = os.path.join(
            output_dir,
            f'{get_model_configs(self.config, "name")}_{task_type}_quantization_encodings.json',
        )
        os.makedirs(output_dir, exist_ok=True)

        with open(output_path, "w") as f:
            json.dump(all_encodings, f, indent=4)

        self.logger.info("Quantization encodings saved to: %s", output_path)