import json
import os
from typing import Dict
import torch
import torch.fx as fx
import torch.nn as nn
from core.frameworks.pytorch.quantization.sdk.graph_utils import ModelGraphUtils
from core.frameworks.pytorch.quantization.sdk.model_preparer import (
_get_module_for_dotted_name,
prepare_model,
)
from core.frameworks.pytorch.quantization.sdk.modules.fake_quantize import (
FakeQuantize as CustomFakeQuantize,
)
from core.frameworks.pytorch.quantization.sdk.modules.module_replacement_registry import (
ModuleReplacements,
)
from core.frameworks.pytorch.quantization.sdk.modules.quantized_modules import (
QuantizationMixin,
)
from core.frameworks.pytorch.quantization.sdk.quantization_utils import (
build_clean_onnx_path,
build_onnx_path_from_torch_name,
get_bitwidth_from_string_dtype,
normalize_class_name,
print_quant_summary,
)
from core.frameworks.pytorch.utils.utils import is_leaf_module
from logger.logger_registry import get_logger
from utils.config_utils import (
get_model_configs,
get_task_configs,
)
from utils.helpers import set_nested_attr
from .modules.function_modules import REPLACEMENTS
from .quantsim_config.quantsim_parser import QuantSimConfigParser
logger = get_logger("main")
class QuantizerEngineer:
def __init__(
self,
model: torch.nn.Module,
configs,
device: torch.device,
):
self._model: torch.nn.Module = model
self.config = configs
self.logger = get_logger("main")
self.device = device
self.task_type = (
"qft" if self.config["qft"] else "ptq" if self.config["ptq"] else "pruning"
)
self.quantsim_config = QuantSimConfigParser()
def update_model(self, model: torch.nn.Module):
self._model = model
    def fuse_all_conv_bn(self, module):
        """Fuse Conv2d/BatchNorm2d pairs that are adjacent in named_modules() order, in place."""
        fusion_list = []
module_list = list(module.named_modules())
for (name1, mod1), (name2, mod2) in zip(module_list, module_list[1:]):
if isinstance(mod1, torch.nn.Conv2d) and isinstance(
mod2, torch.nn.BatchNorm2d
):
fusion_list.append([name1, name2])
if fusion_list:
torch.ao.quantization.fuse_modules(module, fusion_list, inplace=True)
def _check_super_groups_module(self, layer_name: str) -> bool:
super_groups = self.quantsim_config.get_super_groups_ops()
next_node_name = self.graphutils.get_next_non_identity_module(layer_name)
if not next_node_name:
return self.quantsim_config.is_default_ops_output_quantized()
layer_name = self.graphutils.get_normalized_layer_name(layer_name)
next_node_name = self.graphutils.get_normalized_layer_name(next_node_name)
if next_node_name:
next_node_name_lower = next_node_name.lower()
node_name_lower = layer_name.lower()
for super_group in super_groups:
super_group_first = super_group[0].lower()
super_group_second = super_group[1].lower()
if (
node_name_lower in super_group_first
and next_node_name_lower in super_group_second
):
return False
elif next_node_name_lower in super_group_second:
return True
return self.quantsim_config.is_default_ops_output_quantized()
def apply_io_quantizer_flags_to_graph(
self, graph: torch.fx.GraphModule
) -> dict[str, tuple[bool, bool]]:
"""
Iterate over all nodes in the graph and determine IO quantizer flags.
Returns:
A dictionary mapping node names to (add_input_quantizer, add_output_quantizer) flags.
"""
io_flags = {}
for node in graph.graph.nodes:
if node.op == "call_module":
flags = self.__get_io_quantizer_flags(graph, node, None)
io_flags[node.name] = flags
                self.logger.debug(
                    f"[IO Quantizer Flags] {node.name}: input={flags[0]}, output={flags[1]}"
                )
                module = _get_module_for_dotted_name(graph, node.target)
                if not isinstance(module, QuantizationMixin):
                    continue
                if flags[0]:
                    module.enable_input_quantization()
                else:
                    module.disable_input_quantization()
                if flags[1]:
                    module.enable_output_quantization()
                else:
                    module.disable_output_quantization()
return io_flags
def _check_super_groups_node(
self,
layer_name: str,
node: torch.fx.Node,
next_node: torch.fx.Node,
next_node_name: str,
) -> bool:
super_groups = self.quantsim_config.get_super_groups_ops()
for super_group in super_groups:
if (
layer_name in super_group[0].lower()
and next_node
and next_node_name in super_group[1].lower()
):
print("#" * 20, layer_name, next_node_name, "#" * 20)
return False
elif layer_name in super_group[1].lower():
return True
return True
def __is_quantizable_module(self, module: torch.nn.Module):
return ModuleReplacements.get_replacement(module) is not None
@classmethod
def __is_quantized_module(cls, module: torch.nn.Module):
return isinstance(module, QuantizationMixin)
def _normalized_op_name(self, module: torch.nn.Module) -> str:
"""Map a module to the normalized op name used in your config."""
return self.graphutils.get_normalized_layer_name(module)
def _match_pattern_from(
self,
graph: torch.fx.GraphModule,
start_node: torch.fx.Node,
pattern_ops: list[str],
):
"""
Try to match a supergroup pattern starting at start_node.
Returns list of nodes if matched, else None.
"""
if start_node.op != "call_module":
return None
try:
first_module = _get_module_for_dotted_name(graph, start_node.target)
except Exception:
return None
first_name = self._normalized_op_name(first_module)
if first_name != pattern_ops[0]:
return None
matched_nodes = [start_node]
curr = start_node
for expected in pattern_ops[1:]:
next_node = self.graphutils.get_next_non_identity_node(curr)
if not next_node or next_node.op != "call_module":
return None
try:
next_module = _get_module_for_dotted_name(graph, next_node.target)
except Exception:
return None
next_name = self._normalized_op_name(next_module)
if next_name != expected:
return None
matched_nodes.append(next_node)
curr = next_node
return matched_nodes
def _collect_supergroup_patterns(self):
"""
Read patterns from your quantsim config.
Expected structure: [["Conv", "BatchNorm", "Relu"], ["MatMul", "Add"]]
"""
cfg = self.quantsim_config.get_model_quantization_config()
patterns = cfg.get("supergroups", [])
cleaned = []
for pat in patterns:
if isinstance(pat, (list, tuple)) and len(pat) >= 2:
cleaned.append([str(x) for x in pat])
return cleaned
def _apply_super_groups_config(self, graph: torch.fx.GraphModule):
"""
Find and apply super-group quantization configuration.
Works for groups of any length >= 2. Prevents overlaps.
"""
patterns = self._collect_supergroup_patterns()
self._supergroup_members = {}
claimed = set()
group_id = 0
for node in graph.graph.nodes:
if node.op != "call_module" or node in claimed:
continue
for pat in patterns:
match = self._match_pattern_from(graph, node, pat)
if not match:
continue
if any(n in claimed for n in match):
continue
size = len(match)
for idx, member in enumerate(match):
self._supergroup_members[member] = (group_id, idx, size)
claimed.add(member)
# Gather actual modules
modules = []
for mnode in match:
m = _get_module_for_dotted_name(graph, mnode.target)
modules.append(m)
# Apply quantizer sharing
self._apply_super_group_action_general(modules)
if getattr(self, "verbose", False):
names = [self._normalized_op_name(m) for m in modules]
print(f"[SuperGroup] id={group_id} matched: {' -> '.join(names)}")
group_id += 1
break
def _belongs_to_super_group(self, node: torch.fx.Node) -> bool:
return hasattr(self, "_supergroup_members") and node in self._supergroup_members
def _supergroup_position(self, node: torch.fx.Node):
"""Return (group_id, idx, size) if node is in a supergroup, else None."""
if self._belongs_to_super_group(node):
return self._supergroup_members[node]
return None
def _apply_super_group_action_general(self, modules: list):
n = len(modules)
if n < 2:
return
last_module = modules[-1]
if not hasattr(last_module, "output_quantizer"):
return
shared_output_q = last_module.output_quantizer
for m in modules[:-1]:
if hasattr(m, "output_quantizer"):
m.output_quantizer = shared_output_q
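    # Effect of the sharing above: for a matched Conv -> BatchNorm -> ReLU supergroup,
    # conv and bn reuse relu's output_quantizer object, so the whole group is calibrated
    # against a single set of output encodings at the group boundary.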
def __get_io_quantizer_flags(
self, graph, node: torch.fx.Node, layer_name: str | None
) -> tuple[bool, bool]:
"""
Decide whether to add input/output quantizers for a node,
respecting model IO policy and per-layer overrides.
"""
layer_name = layer_name or self.graphutils.get_normalized_layer_name(node)
model_quant_config = self.quantsim_config.get_model_quantization_config()
add_input_quantizer = self.quantsim_config.is_default_ops_input_quantized()
add_output_quantizer = self.quantsim_config.is_default_ops_output_quantized()
if self.graphutils._is_first(node, layer_name):
add_input_quantizer = model_quant_config.get("input_quantized", False)
if self.graphutils._is_last(node, layer_name):
add_output_quantizer = model_quant_config.get("output_quantized", False)
# No need for explicit supergroup override:
# all sharing is already handled in _apply_super_group_action_general.
add_input_quantizer = self._apply_layer_override(
layer_name, add_input_quantizer, is_input=True
)
add_output_quantizer = self._apply_layer_override(
layer_name, add_output_quantizer, is_input=False
)
        return add_input_quantizer, add_output_quantizer
def _apply_layer_override(
self, layer_name: str, current_value: bool, is_input: bool
) -> bool:
if is_input:
if (
layer_name
in self.quantsim_config.get_layers_to_skip_from_input_quantizers()
):
return False
if layer_name in self.quantsim_config.get_layers_to_add_input_quantizers():
return True
else:
if (
layer_name
in self.quantsim_config.get_layers_to_skip_from_output_quantizers()
):
return False
if layer_name in self.quantsim_config.get_layers_to_add_output_quantizers():
return True
return current_value
def _add_quantization_wrappers(self, module, prefix=""):
if self.__is_quantized_module(module):
return
for module_name, module_ref in module.named_children():
full_name = f"{prefix}.{module_name}" if prefix else module_name
self.logger.info("nn.Module found : %s", module_ref)
print("nn.Module found : %s", module_ref)
if self.__is_quantizable_module(module_ref) and is_leaf_module(module_ref):
quantized_module = self._create_quantizer_module(module_ref, full_name)
                if not quantized_module:
                    self.logger.warning(
                        "No quantized replacement registered for %s; leaving it unchanged",
                        full_name,
                    )
                    continue
setattr(module, module_name, quantized_module)
else:
self._add_quantization_wrappers(module_ref, prefix=full_name)
def _create_quantizer_module(
self, module_to_quantize: torch.nn.Module, module_name: str
):
param_per_channel = get_task_configs(
self.config, "ptq", "parameter_per_channel", False
)
act_per_channel = get_task_configs(
self.config, "ptq", "activation_per_channel", False
)
param_per_tensor = get_task_configs(
self.config, "ptq", "parameter_per_tensor", True
)
act_per_tensor = get_task_configs(
self.config, "ptq", "activation_per_tensor", True
)
param_is_symmetric = get_task_configs(
self.config, "ptq", "parameter_is_symmetric", True
)
act_is_symmetric = get_task_configs(
self.config, "ptq", "activation_is_symmetric", False
)
global_activation_dtype = get_task_configs(
self.config, "ptq", "activation_dtype", "int4"
)
global_param_dtype = get_task_configs(
self.config, "ptq", "parameter_dtype", "int4"
)
global_activation_observer = get_task_configs(
self.config, "ptq", "activation_observer"
)
global_weight_observer = get_task_configs(self.config, "ptq", "weight_observer")
quantizer = ModuleReplacements.get_replacement(module_to_quantize)
        if not quantizer:
            self.logger.warning(
                "No quantized replacement registered for %s", type(module_to_quantize)
            )
            return None
        # Attach quantization settings so the replacement wrapper can read them.
setattr(module_to_quantize, "activation_dtype", global_activation_dtype)
setattr(module_to_quantize, "parameter_dtype", global_param_dtype)
setattr(module_to_quantize, "parameter_observer", global_weight_observer)
setattr(module_to_quantize, "activation_observer", global_activation_observer)
setattr(module_to_quantize, "param_per_channel", param_per_channel)
setattr(module_to_quantize, "act_per_channel", act_per_channel)
setattr(module_to_quantize, "param_per_tensor", param_per_tensor)
setattr(module_to_quantize, "act_per_tensor", act_per_tensor)
setattr(module_to_quantize, "param_is_symmetric", param_is_symmetric)
setattr(module_to_quantize, "act_is_symmetric", act_is_symmetric)
self.logger.info(
f"Replacing {type(module_to_quantize)} with {quantizer.__name__} for quantization"
)
quantized_module = quantizer(
_module_to_wrap=module_to_quantize,
# add_input_quantizers=add_input_quantizers,
# add_output_quantizers=add_output_quantizers,
)
return quantized_module
    def prepare(self):
        """
        Prepare the model for quantization: fuse Conv/BN pairs, trace the graph,
        replace quantizable leaf modules with their quantized wrappers, and apply
        the input/output quantizer flags derived from the quantsim config.
        Returns:
            The prepared model.
        """
self.logger.info("=" * 60)
self.logger.info("Preparing model for QFT (Quantization-Aware Fine-Tuning)")
self.logger.info("=" * 60)
self._model.eval()
self.fuse_all_conv_bn(self._model)
        try:
            self.traced_graph = fx.symbolic_trace(self._model)
            self.graphutils = ModelGraphUtils(self._model, self.traced_graph)
            self._model = prepare_model(self._model)
            self.graphutils.update_model(self._model)
            self._add_quantization_wrappers(self._model)
            self.graphutils.update_graph(self._model)
            self.apply_io_quantizer_flags_to_graph(self._model)
            self.logger.info(self._model)
        except Exception as e:
            self.logger.exception(e)
            self.graphutils = ModelGraphUtils(self._model, None)
            self.logger.info("Model is not graph traceable. Cannot replace math ops.")
            self._add_quantization_wrappers(self._model)
self.logger.info("Model after preparing")
self.logger.info(self._model)
print_quant_summary(self._model)
return self._model
def convert(self, model) -> None:
"""
Convert the model to a quantized model.
This method is called after the training and preparation steps.
"""
for name, child in list(model.named_children()):
if isinstance(child, QuantizationMixin):
self.logger.info(f"Replacing Quantized module: {name}")
module: QuantizationMixin = child
scale_fp = module.get_scale_fp()
wrapped_module: nn.Module = module._module_to_wrap
                for k, v in scale_fp.items():
                    # register_buffer also sets the attribute, so no separate setattr is needed.
                    wrapped_module.register_buffer(f"{k}_scale", v["scale"])
                    wrapped_module.register_buffer(f"{k}_zero_point", v["zero_point"])
setattr(model, name, wrapped_module)
continue
self.convert(child)
def export_model(self, model: torch.nn.Module, task_type: str) -> None:
"""
Export the quantized model to a format suitable for deployment.
"""
self.model_name = get_model_configs(self.config)["name"]
export_format = self.config["export"]["format"]
        output_dir = os.path.join(self.config["export"]["output_dir"], self.model_name)
if export_format == "onnx":
export_model = model.apply(torch.ao.quantization.disable_observer)
export_model.cpu()
            os.makedirs(output_dir, exist_ok=True)
param_dtype = get_task_configs(
self.config, task_type, "parameter_dtype", "int4"
)
opset = get_task_configs(self.config, "export", "opset", 13)
attribute = get_task_configs(self.config, "export", "attribute", None)
if attribute is not None:
if not hasattr(export_model, attribute):
raise ValueError(f"Model has no attribute '{attribute}' for export")
export_model = getattr(export_model, attribute)
onnx_file = (
f"{self.model_name}_quantized_model_{task_type}_{param_dtype}.onnx"
)
output_path = os.path.join(output_dir, onnx_file)
self.logger.info("=" * 60)
self.logger.info("Exporting the quantized model")
self.logger.info("=" * 60)
self.logger.info(f"Model Name : {self.model_name}")
self.logger.info(f"Export Format : {export_format}")
self.logger.info(f"Output Path : {output_path}")
self.logger.info("=" * 60)
try:
                torch.onnx.export(
                    export_model,
                    # self._input_shape is expected to be set on the engine before export.
                    torch.randn(self._input_shape),  # type: ignore
output_path,
export_params=True,
opset_version=opset,
do_constant_folding=True,
input_names=["input_image"],
output_names=["output"],
)
self.logger.info("Model export completed successfully.")
except Exception as e:
self.logger.error("Model export failed.")
self.logger.exception(e)
def __extract_encoding(self, module: CustomFakeQuantize) -> Dict:
scale = module.scale
zero_point = module.zero_point
quant_min = module.quant_min
quant_max = module.quant_max
qscheme = module.qscheme
dtype = str(module.dtype)
bitwidth = get_bitwidth_from_string_dtype(dtype)
is_symmetric = qscheme in [
torch.per_tensor_symmetric,
torch.per_channel_symmetric,
]
if is_symmetric:
encoding_min = -scale * ((quant_max - quant_min) / 2)
encoding_max = scale * ((quant_max - quant_min) / 2)
else:
encoding_min = scale * (quant_min - zero_point)
encoding_max = scale * (quant_max - zero_point)
base_info = {
"bitwidth": bitwidth,
"quant_min": quant_min,
"quant_max": quant_max,
"qscheme": str(qscheme),
"dtype": dtype,
"is_symmetric": is_symmetric,
}
if scale.numel() == 1:
return {
**base_info,
"encodings": [
{
"scale": scale.item(),
"offset": zero_point.item(),
"min": encoding_min.item(),
"max": encoding_max.item(),
}
],
}
else:
encodings = []
for i in range(scale.numel()):
encodings.append(
{
"scale": scale[i].item(),
"offset": zero_point[i].item(),
"min": encoding_min[i].item(),
"max": encoding_max[i].item(),
}
)
return {**base_info, "encodings": encodings}
def __get_quant_min_max(self, dtype):
dtype = dtype.lower()
if dtype == "int4":
return -8, 7, torch.qint8
elif dtype == "uint4":
return 0, 15, torch.quint8
elif dtype == "int8":
return -128, 127, torch.qint8
elif dtype == "uint8":
return 0, 255, torch.quint8
elif dtype == "int16":
return -32768, 32767, torch.qint32
# return 0, 65535, torch.qint32
else:
raise ValueError(f"Unsupported dtype: {dtype}")
    def __get_encodings_from_model(self, model) -> Dict[str, Dict]:
        """
        Extracts quantization parameters (scale, zero_point, min, max, etc.)
        from a model prepared by this engine.
        Returns a dictionary of encodings per FakeQuantize module, keyed by ONNX-style path.
        """
        encoding_dict = {}
for name, child in model.named_modules():
if not name:
continue
if "quant" in name:
continue
onnx_path = build_onnx_path_from_torch_name(name)
cls_name = normalize_class_name(child.__class__.__name__)
output_name = build_clean_onnx_path(f"{onnx_path}/{cls_name}")
output_name = output_name.replace("Quantized", "")
if "module_" in output_name:
output_name.replace("module_", "")
if isinstance(child, QuantizationMixin):
data = {}
for sub_name, sub_module in child.named_modules():
if isinstance(sub_module, (CustomFakeQuantize)):
encoding_info = self.__extract_encoding(sub_module)
data[f"{sub_name.replace('_quantizers','')}"] = encoding_info
encoding_dict[output_name] = data
return encoding_dict
def generate_embeddings(self, attribute):
"""
Iterates through model modules, collects quantization encodings from QuantizationMixin modules,
and writes them to a JSON file as a list.
"""
        output_dir = os.path.join(
            self.config["export"]["output_dir"],
            get_model_configs(self.config, "name"),
        )
        model = self._model
        if attribute is not None:
            if not hasattr(model, attribute):
                raise ValueError(f"Model has no attribute '{attribute}' for export")
            model = getattr(model, attribute)
        all_encodings = self.__get_encodings_from_model(model)
        task_type = self.task_type
output_path = os.path.join(
output_dir,
f'{get_model_configs(self.config, "name")}_{task_type}_quantization_encodings.json',
)
os.makedirs(output_dir, exist_ok=True)
with open(output_path, "w") as f:
json.dump(all_encodings, f, indent=4)
print(f"Quantization encodings saved to: {output_path}")