QuantizerEnginer: PyTorch quantization engine (prepare / convert / export)
Tue Aug 19 2025 18:40:28 GMT+0000 (Coordinated Universal Time)
Saved by @reiddd #python
import json
import os
from typing import Dict

import torch
import torch.fx as fx
import torch.nn as nn

from core.frameworks.pytorch.quantization.sdk.graph_utils import ModelGraphUtils
from core.frameworks.pytorch.quantization.sdk.model_preparer import (
    _get_module_for_dotted_name,
    prepare_model,
)
from core.frameworks.pytorch.quantization.sdk.modules.fake_quantize import (
    FakeQuantize as CustomFakeQuantize,
)
from core.frameworks.pytorch.quantization.sdk.modules.module_replacement_registry import (
    ModuleReplacements,
)
from core.frameworks.pytorch.quantization.sdk.modules.quantized_modules import (
    QuantizationMixin,
)
from core.frameworks.pytorch.quantization.sdk.quantization_utils import (
    build_clean_onnx_path,
    build_onnx_path_from_torch_name,
    get_bitwidth_from_string_dtype,
    normalize_class_name,
    print_quant_summary,
)
from core.frameworks.pytorch.utils.utils import is_leaf_module
from logger.logger_registry import get_logger
from utils.config_utils import (
    get_model_configs,
    get_task_configs,
)
from utils.helpers import set_nested_attr
from .modules.function_modules import REPLACEMENTS
from .quantsim_config.quantsim_parser import QuantSimConfigParser

logger = get_logger("main")


class QuantizerEnginer:
    def __init__(
        self,
        model: torch.nn.Module,
        configs,
        device: torch.device,
    ):
        self._model: torch.nn.Module = model
        self.config = configs
        self.logger = get_logger("main")
        self.device = device
        self.task_type = (
            "qft" if self.config["qft"] else "ptq" if self.config["ptq"] else "pruning"
        )
        self.quantsim_config = QuantSimConfigParser()

    def update_model(self, model: torch.nn.Module):
        self._model = model

    def fuse_all_conv_bn(self, module):
        fusion_list = []
        module_list = list(module.named_modules())
        for (name1, mod1), (name2, mod2) in zip(module_list, module_list[1:]):
            if isinstance(mod1, torch.nn.Conv2d) and isinstance(
                mod2, torch.nn.BatchNorm2d
            ):
                fusion_list.append([name1, name2])
        if fusion_list:
            torch.ao.quantization.fuse_modules(module, fusion_list, inplace=True)

    def _check_super_groups_module(self, layer_name: str) -> bool:
        super_groups = self.quantsim_config.get_super_groups_ops()
        next_node_name = self.graphutils.get_next_non_identity_module(layer_name)
        if not next_node_name:
            return self.quantsim_config.is_default_ops_output_quantized()

        layer_name = self.graphutils.get_normalized_layer_name(layer_name)
        next_node_name = self.graphutils.get_normalized_layer_name(next_node_name)

        if next_node_name:
            next_node_name_lower = next_node_name.lower()
            node_name_lower = layer_name.lower()
            for super_group in super_groups:
                super_group_first = super_group[0].lower()
                super_group_second = super_group[1].lower()
                if (
                    node_name_lower in super_group_first
                    and next_node_name_lower in super_group_second
                ):
                    return False
                elif next_node_name_lower in super_group_second:
                    return True
        return self.quantsim_config.is_default_ops_output_quantized()

    def apply_io_quantizer_flags_to_graph(
        self, graph: torch.fx.GraphModule
    ) -> dict[str, tuple[bool, bool]]:
        """
        Iterate over all nodes in the graph and determine IO quantizer flags.

        Returns:
            A dictionary mapping node names to (add_input_quantizer, add_output_quantizer) flags.
""" io_flags = {} for node in graph.graph.nodes: if node.op == "call_module": flags = self.__get_io_quantizer_flags(graph, node, None) io_flags[node.name] = flags print( f"[IO Quantizer Flags] {node.name}: Input={flags[0]}, Output={flags[1]}" ) module: QuantizationMixin = _get_module_for_dotted_name( graph, node.target ) if not isinstance(module, QuantizationMixin): continue if not flags[0]: module.disable_input_quantization() if not flags[1]: module.disable_output_quantization() if flags[0]: module.enable_input_quantization() if flags[1]: module.enable_output_quantization() return io_flags def _check_super_groups_node( self, layer_name: str, node: torch.fx.Node, next_node: torch.fx.Node, next_node_name: str, ) -> bool: super_groups = self.quantsim_config.get_super_groups_ops() for super_group in super_groups: if ( layer_name in super_group[0].lower() and next_node and next_node_name in super_group[1].lower() ): print("#" * 20, layer_name, next_node_name, "#" * 20) return False elif layer_name in super_group[1].lower(): return True return True def __is_quantizable_module(self, module: torch.nn.Module): return ModuleReplacements.get_replacement(module) is not None @classmethod def __is_quantized_module(cls, module: torch.nn.Module): return isinstance(module, QuantizationMixin) def _normalized_op_name(self, module: torch.nn.Module) -> str: """Map a module to the normalized op name used in your config.""" return self.graphutils.get_normalized_layer_name(module) def _match_pattern_from( self, graph: torch.fx.GraphModule, start_node: torch.fx.Node, pattern_ops: list[str], ): """ Try to match a supergroup pattern starting at start_node. Returns list of nodes if matched, else None. """ if start_node.op != "call_module": return None try: first_module = _get_module_for_dotted_name(graph, start_node.target) except Exception: return None first_name = self._normalized_op_name(first_module) if first_name != pattern_ops[0]: return None matched_nodes = [start_node] curr = start_node for expected in pattern_ops[1:]: next_node = self.graphutils.get_next_non_identity_node(curr) if not next_node or next_node.op != "call_module": return None try: next_module = _get_module_for_dotted_name(graph, next_node.target) except Exception: return None next_name = self._normalized_op_name(next_module) if next_name != expected: return None matched_nodes.append(next_node) curr = next_node return matched_nodes def _collect_supergroup_patterns(self): """ Read patterns from your quantsim config. Expected structure: [["Conv", "BatchNorm", "Relu"], ["MatMul", "Add"]] """ cfg = self.quantsim_config.get_model_quantization_config() patterns = cfg.get("supergroups", []) cleaned = [] for pat in patterns: if isinstance(pat, (list, tuple)) and len(pat) >= 2: cleaned.append([str(x) for x in pat]) return cleaned def _apply_super_groups_config(self, graph: torch.fx.GraphModule): """ Find and apply super-group quantization configuration. Works for groups of any length >= 2. Prevents overlaps. 
""" patterns = self._collect_supergroup_patterns() self._supergroup_members = {} claimed = set() group_id = 0 for node in graph.graph.nodes: if node.op != "call_module" or node in claimed: continue for pat in patterns: match = self._match_pattern_from(graph, node, pat) if not match: continue if any(n in claimed for n in match): continue size = len(match) for idx, member in enumerate(match): self._supergroup_members[member] = (group_id, idx, size) claimed.add(member) # Gather actual modules modules = [] for mnode in match: m = _get_module_for_dotted_name(graph, mnode.target) modules.append(m) # Apply quantizer sharing self._apply_super_group_action_general(modules) if getattr(self, "verbose", False): names = [self._normalized_op_name(m) for m in modules] print(f"[SuperGroup] id={group_id} matched: {' -> '.join(names)}") group_id += 1 break def _belongs_to_super_group(self, node: torch.fx.Node) -> bool: return hasattr(self, "_supergroup_members") and node in self._supergroup_members def _supergroup_position(self, node: torch.fx.Node): """Return (group_id, idx, size) if node is in a supergroup, else None.""" if self._belongs_to_super_group(node): return self._supergroup_members[node] return None def _apply_super_group_action_general(self, modules: list): n = len(modules) if n < 2: return last_module = modules[-1] if not hasattr(last_module, "output_quantizer"): return shared_output_q = last_module.output_quantizer for m in modules[:-1]: if hasattr(m, "output_quantizer"): m.output_quantizer = shared_output_q def __get_io_quantizer_flags( self, graph, node: torch.fx.Node, layer_name: str | None ) -> tuple[bool, bool]: """ Decide whether to add input/output quantizers for a node, respecting model IO policy and per-layer overrides. """ layer_name = layer_name or self.graphutils.get_normalized_layer_name(node) model_quant_config = self.quantsim_config.get_model_quantization_config() add_input_quantizer = self.quantsim_config.is_default_ops_input_quantized() add_output_quantizer = self.quantsim_config.is_default_ops_output_quantized() if self.graphutils._is_first(node, layer_name): add_input_quantizer = model_quant_config.get("input_quantized", False) if self.graphutils._is_last(node, layer_name): add_output_quantizer = model_quant_config.get("output_quantized", False) # No need for explicit supergroup override: # all sharing is already handled in _apply_super_group_action_general. 
        add_input_quantizer = self._apply_layer_override(
            layer_name, add_input_quantizer, is_input=True
        )
        add_output_quantizer = self._apply_layer_override(
            layer_name, add_output_quantizer, is_input=False
        )

        _module = _get_module_for_dotted_name(graph, node.target)
        return add_input_quantizer, add_output_quantizer

    def _apply_layer_override(
        self, layer_name: str, current_value: bool, is_input: bool
    ) -> bool:
        if is_input:
            if (
                layer_name
                in self.quantsim_config.get_layers_to_skip_from_input_quantizers()
            ):
                return False
            if layer_name in self.quantsim_config.get_layers_to_add_input_quantizers():
                return True
        else:
            if (
                layer_name
                in self.quantsim_config.get_layers_to_skip_from_output_quantizers()
            ):
                return False
            if layer_name in self.quantsim_config.get_layers_to_add_output_quantizers():
                return True
        return current_value

    def _add_quantization_wrappers(self, module, prefix=""):
        if self.__is_quantized_module(module):
            return
        for module_name, module_ref in module.named_children():
            full_name = f"{prefix}.{module_name}" if prefix else module_name
            self.logger.info("nn.Module found : %s", module_ref)
            print(f"nn.Module found : {module_ref}")
            if self.__is_quantizable_module(module_ref) and is_leaf_module(module_ref):
                quantized_module = self._create_quantizer_module(module_ref, full_name)
                if not quantized_module:
                    self.logger.info(f"Please register {full_name}")
                    continue
                setattr(module, module_name, quantized_module)
            else:
                self._add_quantization_wrappers(module_ref, prefix=full_name)

    def _create_quantizer_module(
        self, module_to_quantize: torch.nn.Module, module_name: str
    ):
        param_per_channel = get_task_configs(
            self.config, "ptq", "parameter_per_channel", False
        )
        act_per_channel = get_task_configs(
            self.config, "ptq", "activation_per_channel", False
        )
        param_per_tensor = get_task_configs(
            self.config, "ptq", "parameter_per_tensor", True
        )
        act_per_tensor = get_task_configs(
            self.config, "ptq", "activation_per_tensor", True
        )
        param_is_symmetric = get_task_configs(
            self.config, "ptq", "parameter_is_symmetric", True
        )
        act_is_symmetric = get_task_configs(
            self.config, "ptq", "activation_is_symmetric", False
        )
        global_activation_dtype = get_task_configs(
            self.config, "ptq", "activation_dtype", "int4"
        )
        global_param_dtype = get_task_configs(
            self.config, "ptq", "parameter_dtype", "int4"
        )
        global_activation_observer = get_task_configs(
            self.config, "ptq", "activation_observer"
        )
        global_weight_observer = get_task_configs(self.config, "ptq", "weight_observer")

        quantizer = ModuleReplacements.get_replacement(module_to_quantize)
        if not quantizer:
            self.logger.info(f"Please register {type(module_to_quantize)}")
            return None

        # registering parameter and activation dtype
        setattr(module_to_quantize, "activation_dtype", global_activation_dtype)
        setattr(module_to_quantize, "parameter_dtype", global_param_dtype)
        setattr(module_to_quantize, "parameter_observer", global_weight_observer)
        setattr(module_to_quantize, "activation_observer", global_activation_observer)
        setattr(module_to_quantize, "param_per_channel", param_per_channel)
        setattr(module_to_quantize, "act_per_channel", act_per_channel)
        setattr(module_to_quantize, "param_per_tensor", param_per_tensor)
        setattr(module_to_quantize, "act_per_tensor", act_per_tensor)
        setattr(module_to_quantize, "param_is_symmetric", param_is_symmetric)
        setattr(module_to_quantize, "act_is_symmetric", act_is_symmetric)

        self.logger.info(
            f"Replacing {type(module_to_quantize)} with {quantizer.__name__} for quantization"
        )
        quantized_module = quantizer(
            _module_to_wrap=module_to_quantize,
            # add_input_quantizers=add_input_quantizers,
            # add_output_quantizers=add_output_quantizers,
        )
        return quantized_module

    def prepare(self):
        """
        Recursively replace every weight-bearing layer with a QuantWrapper.
        """
        self.logger.info("=" * 60)
        self.logger.info("Preparing model for QFT (Quantization-Aware Fine-Tuning)")
        self.logger.info("=" * 60)
        self._model.eval()
        self.fuse_all_conv_bn(self._model)
        # try:
        self.traced_graph = fx.symbolic_trace(self._model)
        self.graphutils = ModelGraphUtils(self._model, self.traced_graph)
        self._model = prepare_model(self._model)
        self.graphutils.update_model(self._model)
        self._add_quantization_wrappers(self._model)
        self.graphutils.update_graph(self._model)
        self.apply_io_quantizer_flags_to_graph(self._model)
        self.logger.info(self._model)
        # except Exception as e:
        #     print(e)
        #     self.graphutils = ModelGraphUtils(self._model, None)
        #     self.logger.info("Model is not graph traceable. Cannot replace math ops.")
        #     self._add_quantization_wrappers(self._model)
        self.logger.info("Model after preparing")
        self.logger.info(self._model)
        print_quant_summary(self._model)
        return self._model

    def convert(self, model) -> None:
        """
        Convert the model to a quantized model.
        This method is called after the training and preparation steps.
        """
        for name, child in list(model.named_children()):
            if isinstance(child, QuantizationMixin):
                self.logger.info(f"Replacing Quantized module: {name}")
                module: QuantizationMixin = child
                scale_fp = module.get_scale_fp()
                wrapped_module: nn.Module = module._module_to_wrap
                for k, v in scale_fp.items():
                    wrapped_module.register_buffer(f"{k}_scale", v["scale"])
                    wrapped_module.register_buffer(f"{k}_zero_point", v["zero_point"])
                    setattr(wrapped_module, f"{k}_scale", v["scale"])
                    setattr(wrapped_module, f"{k}_zero_point", v["zero_point"])
                setattr(model, name, wrapped_module)
                continue
            self.convert(child)

    def export_model(self, model: torch.nn.Module, task_type: str) -> None:
        """
        Export the quantized model to a format suitable for deployment.
""" self.model_name = get_model_configs(self.config)["name"] export_format = self.config["export"]["format"] output_dir = self.config["export"]["output_dir"] + "/" + self.model_name if export_format == "onnx": export_model = model.apply(torch.ao.quantization.disable_observer) export_model.cpu() if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) param_dtype = get_task_configs( self.config, task_type, "parameter_dtype", "int4" ) opset = get_task_configs(self.config, "export", "opset", 13) attribute = get_task_configs(self.config, "export", "attribute", None) if attribute is not None: if not hasattr(export_model, attribute): raise ValueError(f"Model has no attribute '{attribute}' for export") export_model = getattr(export_model, attribute) onnx_file = ( f"{self.model_name}_quantized_model_{task_type}_{param_dtype}.onnx" ) output_path = os.path.join(output_dir, onnx_file) self.logger.info("=" * 60) self.logger.info("Exporting the quantized model") self.logger.info("=" * 60) self.logger.info(f"Model Name : {self.model_name}") self.logger.info(f"Export Format : {export_format}") self.logger.info(f"Output Path : {output_path}") self.logger.info("=" * 60) try: torch.onnx.export( export_model, torch.randn(self._input_shape), # type: ignore output_path, export_params=True, opset_version=opset, do_constant_folding=True, input_names=["input_image"], output_names=["output"], ) self.logger.info("Model export completed successfully.") except Exception as e: self.logger.error("Model export failed.") self.logger.exception(e) def __extract_encoding(self, module: CustomFakeQuantize) -> Dict: scale = module.scale zero_point = module.zero_point quant_min = module.quant_min quant_max = module.quant_max qscheme = module.qscheme dtype = str(module.dtype) bitwidth = get_bitwidth_from_string_dtype(dtype) is_symmetric = qscheme in [ torch.per_tensor_symmetric, torch.per_channel_symmetric, ] if is_symmetric: encoding_min = -scale * ((quant_max - quant_min) / 2) encoding_max = scale * ((quant_max - quant_min) / 2) else: encoding_min = scale * (quant_min - zero_point) encoding_max = scale * (quant_max - zero_point) base_info = { "bitwidth": bitwidth, "quant_min": quant_min, "quant_max": quant_max, "qscheme": str(qscheme), "dtype": dtype, "is_symmetric": is_symmetric, } if scale.numel() == 1: return { **base_info, "encodings": [ { "scale": scale.item(), "offset": zero_point.item(), "min": encoding_min.item(), "max": encoding_max.item(), } ], } else: encodings = [] for i in range(scale.numel()): encodings.append( { "scale": scale[i].item(), "offset": zero_point[i].item(), "min": encoding_min[i].item(), "max": encoding_max[i].item(), } ) return {**base_info, "encodings": encodings} def __get_quant_min_max(self, dtype): dtype = dtype.lower() if dtype == "int4": return -8, 7, torch.qint8 elif dtype == "uint4": return 0, 15, torch.quint8 elif dtype == "int8": return -128, 127, torch.qint8 elif dtype == "uint8": return 0, 255, torch.quint8 elif dtype == "int16": return -32768, 32767, torch.qint32 # return 0, 65535, torch.qint32 else: raise ValueError(f"Unsupported dtype: {dtype}") def __get_encodings_from_model(self, model, encoding_dict=None) -> Dict[str, Dict]: """ Extracts quantization parameters (scale, zero_point, min, max, etc.) from a model prepared using torch.ao.quantization.prepare_qat. Returns a hierarchical dictionary of encodings per FakeQuantize module. 
""" encoding_dict = {} for name, child in model.named_modules(): if not name: continue if "quant" in name: continue onnx_path = build_onnx_path_from_torch_name(name) cls_name = normalize_class_name(child.__class__.__name__) output_name = build_clean_onnx_path(f"{onnx_path}/{cls_name}") output_name = output_name.replace("Quantized", "") if "module_" in output_name: output_name.replace("module_", "") if isinstance(child, QuantizationMixin): data = {} for sub_name, sub_module in child.named_modules(): if isinstance(sub_module, (CustomFakeQuantize)): encoding_info = self.__extract_encoding(sub_module) data[f"{sub_name.replace('_quantizers','')}"] = encoding_info encoding_dict[output_name] = data return encoding_dict def generate_embeddings(self, attribute): """ Iterates through model modules, collects quantization encodings from QuantizationMixin modules, and writes them to a JSON file as a list. """ output_dir = ( self.config["export"]["output_dir"] + "/" + get_model_configs(self.config, "name") ) # if attribute is not None: # if not hasattr(model, attribute): # raise ValueError(f"Model has no attribute '{attribute}' for export") # model = getattr(model, attribute) all_encodings = self.__get_encodings_from_model( self._model # if isinstance(attribute, str) and not hasattr(self._model, attribute) # else isinstance(attribute, str) and getattr(self._model, attribute) ) task_type = ( "qft" if get_model_configs(self.config, "qft") else "ptq" if get_model_configs(self.config, "ptq") else "pruning" ) output_path = os.path.join( output_dir, f'{get_model_configs(self.config, "name")}_{task_type}_quantization_encodings.json', ) os.makedirs(output_dir, exist_ok=True) with open(output_path, "w") as f: json.dump(all_encodings, f, indent=4) print(f"Quantization encodings saved to: {output_path}") Update all method according to my quantsimparse ignore aimet i have sent one quantsimparser class right some methods that you have used here is wrong
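A minimal end-to-end usage sketch of the class above, using a ResNet-18 purely for illustration; load_configs and calibration_loader are hypothetical placeholders for the project's own config loading and calibration data pipeline, and the input shape is an assumed value.

# Hypothetical driver showing the intended call order of QuantizerEnginer.
# load_configs and calibration_loader are placeholders, not real project APIs.
import torch
import torchvision

configs = load_configs("configs/resnet18_ptq.yaml")  # placeholder: project-specific config loader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torchvision.models.resnet18(weights=None).to(device)

engine = QuantizerEnginer(model, configs, device)
prepared = engine.prepare()                 # fuse Conv+BN, wrap leaf modules, apply IO quantizer flags

prepared.eval()
with torch.no_grad():                       # calibration pass populates the observers (PTQ)
    for images, _ in calibration_loader:    # placeholder DataLoader of calibration batches
        prepared(images.to(device))

engine._input_shape = (1, 3, 224, 224)      # assumed shape; export_model reads self._input_shape
engine.convert(prepared)                    # fold scale/zero-point buffers back into wrapped modules
engine.export_model(prepared, engine.task_type)  # ONNX export per the "export" config section
engine.generate_embeddings(attribute=None)       # dump per-layer quantization encodings to JSON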