#!/usr/bin/env python
# ******************************************************************************
# Copyright 2025 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# Public API: `quantize` is the only name exported by a star-import of this module.
__all__ = ["quantize"]

import re
import warnings
import tqdm
from quantizeml.onnx_support.quantization.calibration import calibrate

from ..compatibility_info import ModelCompatibilityInfo
from ..tools import (ONNXExtractorModel, ensure_model_type, is_quantized, replace_graph,
                     convert_model_to)
from ..transforms import remove_pointless_quantizers
from .sequences import search_quantizable_sequences, _infer_partial_io
from .core import quantize_calibrated, _WARNINGS_TO_BE_HANDLE_AS_ERRORS


def _find_source_error(error_warning_list):
    """Return the first entry of the queue that must be treated as a quantization failure.

    An entry qualifies when it is an exception instance, or when it is a recorded
    ``warnings.WarningMessage`` whose text contains one of the patterns listed in
    ``_WARNINGS_TO_BE_HANDLE_AS_ERRORS`` (i.e. a partial quantization occurred).

    Args:
        error_warning_list (iterable): recorded warnings and/or raised exceptions,
            in the order they were collected.

    Returns:
        Exception or warnings.WarningMessage or None: the first qualifying entry,
        or None when the queue holds no error source.
    """
    def _is_failure(entry):
        # Any exception is a failure, whatever its message.
        if isinstance(entry, Exception):
            return True
        # Anything that is neither an exception nor a recorded warning is skipped.
        if not isinstance(entry, warnings.WarningMessage):
            return False
        # Only warnings flagged as "to be handled as errors" count.
        return any(pattern in str(entry.message)
                   for pattern in _WARNINGS_TO_BE_HANDLE_AS_ERRORS)

    return next((entry for entry in error_warning_list if _is_failure(entry)), None)


def _find_tensors_by_error_msg(error_msg, model, iname_to_nodes):
    """Identify the nodes blamed by a quantization failure and the tensors bounding the rest.

    Args:
        error_msg (warnings.WarningMessage or Exception): the failure reported while
            quantizing ``model``. Any object that is not a ``WarningMessage`` is handled
            as an unparsed error.
        model: the model whose nodes are inspected. Every node must have a name.
        iname_to_nodes (dict): mapping forwarded as-is to ``_infer_partial_io``.

    Returns:
        tuple: the list of faulty nodes, the input names and output names returned by
        ``_infer_partial_io`` for the nodes that are neither quantized nor faulty, and
        the failure text as a string.
    """
    graph_nodes = model.nodes()
    assert all(node.name != "" for node in graph_nodes), "node names are required."

    if not isinstance(error_msg, warnings.WarningMessage):
        # The failure cannot be parsed: report it and blame the first node by default.
        warnings.warn(f"Node was not found in error (not handled yet): {error_msg}.", stacklevel=2)
        reason = str(error_msg)
        blamed = [graph_nodes[0]]
    else:
        reason = str(error_msg.message)
        # Unless the message lists the offending nodes, nobody can be blamed.
        blamed = []
        if "Impossible to quantize" in reason:
            # The message embeds printed node protos (e.g. [node_proto1, node_proto2, ...]).
            # Drop the 'attribute { ... }' sections first, then collect the quoted names
            # that are left in the text.
            stripped = re.sub(r'attribute\s*{[^}]*}', '', reason, flags=re.DOTALL)
            quoted_names = re.findall(r'name:\s*"([^"]+)"', stripped)
            blamed = [node for node in graph_nodes if node.name in quoted_names]

    # Nodes still to be quantized: neither quantized already nor blamed for the failure.
    pending = [node for node in graph_nodes if not is_quantized(node) and node not in blamed]

    # Deduce the split tensors that delimit the pending nodes.
    split_inputs, split_outputs = _infer_partial_io(
        pending,
        iname_to_nodes,
        exclude=list(model.get_initializer_name_set()),
        graph_outputs={out.name for out in model.output})
    return blamed, split_inputs, split_outputs, reason


@ensure_model_type
def quantize_sequential(model,
                        compatibility_info,
                        tensors_range,
                        input_dtype='int8',
                        iname_to_nodes=None,
                        bar=True):
    """Quantize a sequential ONNX model, handling partial quantization and errors as needed.

    The model is split into quantizable sequences and each one is quantized on its own.
    When quantizing a sequence raises an exception, or emits a warning listed in
    ``_WARNINGS_TO_BE_HANDLE_AS_ERRORS``, the nodes blamed for the failure are recorded in
    ``compatibility_info`` and the nodes left unquantized are processed by a recursive call.
    Each resulting sub-model is then replaced into a clone of the input model.

    Note this function expects a sanitized model.

    Args:
        model (Any): the ONNX model to quantize.
        compatibility_info (ModelCompatibilityInfo): an existing ModelCompatibilityInfo
            object to accumulate incompatibility information during quantization.
        tensors_range (TensorsData): calibration ranges. If not provided,
            they are computed by increasing the runtime.
        input_dtype (np.dtype or str, optional): expected model input format. If given as a string,
            should follow numpy string type requirements. Defaults to 'int8'.
        iname_to_nodes (dict, optional): mapping from input tensor names to nodes
            consuming them. If None, it is computed. Defaults to None.
        bar (bool, optional): whether to show a progress bar. Defaults to True.

    Returns:
        Any: the quantized model.
    """
    if len(model.nodes()) == 0:
        # Return the empty graph.
        return model

    if iname_to_nodes is None:
        iname_to_nodes = model.input_name_to_nodes()

    # Search quantizable sequences in the model.
    sequences = search_quantizable_sequences(model, compatibility_info)

    # Parameters shared by every quantization call. 'input_dtype' is downgraded to 'int8'
    # in place as soon as a quantized node exists, and stays so for the following sequences.
    qparams = dict(input_dtype=input_dtype, tensors_range=tensors_range)
    qmodel = model.clone()
    for base_model in tqdm.tqdm(sequences, desc="Quantizing", disable=not bar):
        input_names = [x.name for x in base_model.graph.input]
        output_names = [x.name for x in base_model.graph.output]
        # Use input_dtype only if no node has been quantized yet.
        if qparams['input_dtype'] != 'int8' and any(is_quantized(node) for node in qmodel.nodes()):
            qparams['input_dtype'] = 'int8'

        # Convert outside the try block: a conversion failure is not a quantization error,
        # and handling it below would rely on 'qbase_model'/'errors_queue' being either
        # unbound (first sequence) or left over from the previous sequence.
        qbase_model = convert_model_to(base_model, ONNXExtractorModel)
        errors_queue = []
        try:
            # Try to quantize the sequence, triggering partial quantization as an error.
            with warnings.catch_warnings(record=True) as errors_queue:
                # Record every warning regardless of the filters installed by the caller:
                # the detection of partial quantization depends on seeing them.
                warnings.simplefilter("always")
                # Quantize model with tensors_range.
                qbase_model = quantize_calibrated(qbase_model,
                                                  ensure_fully_quantized=False,
                                                  **qparams)
        except Exception as e:
            # Append the error into the queue. Please note that we give higher priority to warnings,
            # since they occur before errors.
            errors_queue.append(e)

        # If there is any source of error, continue the algorithm.
        if source_error := _find_source_error(errors_queue):
            # Try to parse the node which produces the error.
            faulty_nodes, i_names, o_names, source_error = _find_tensors_by_error_msg(
                source_error, qbase_model, iname_to_nodes)
            # Add nodes to info with its error message.
            if faulty_nodes and compatibility_info is not None:
                compatibility_info._set_incompatibility(node_sequence=faulty_nodes,
                                                        stage="Quantization",
                                                        faulty_node=faulty_nodes[0].name,
                                                        reason=source_error)
            # Quantize remaining nodes.
            # Note input dtype changes if there are some quantized node.
            head_model = qbase_model.extract_model(input_names=i_names, output_names=o_names)
            if qparams["input_dtype"] != 'int8' and any(is_quantized(node)
                                                        for node in qbase_model.nodes()):
                qparams["input_dtype"] = 'int8'
            qhead_model = quantize_sequential(head_model,
                                              compatibility_info,
                                              bar=False,
                                              iname_to_nodes=iname_to_nodes,
                                              **qparams)
            replace_graph(qbase_model, qhead_model, from_tensors=i_names, until_tensors=o_names)

        # Replace the quantized sub-model into the main model.
        replace_graph(qmodel, qbase_model, input_names, output_names)
    return qmodel


@ensure_model_type
def quantize(model, input_dtype='uint8', samples=None, num_samples=1):
    """Quantize a float ONNX model and report the nodes that could not be quantized.

    Calibration ranges are computed first, then the model is handed to
    ``quantize_sequential``, which quantizes every quantizable part of the graph and
    records in a ``ModelCompatibilityInfo`` the nodes that prevented quantization.
    Pointless quantizers are finally removed from the result.

    Note that this function expects a sanitized model.

    Args:
        model (Any): the ONNX model to quantize.
        input_dtype (np.dtype or str, optional): expected model input format. If given as a string,
            should follow numpy string type requirements. Defaults to 'uint8'.
        samples (list of numpy arrays, optional): list of input samples to use for
            calibration. If not provided, random samples will be generated. Defaults
            to None.
        num_samples (int, optional): number of samples to use for calibration.
            Defaults to 1.

    Returns:
        Any, ModelCompatibilityInfo: the quantized model and the incompatibilities found
        while quantizing it.

    Raises:
        ValueError: if the model already contains a quantized node.
    """
    # Prevent requantization: stop at the first quantized node found.
    for node in model.nodes():
        if is_quantized(node):
            raise ValueError("Requantizing a model is not supported. "
                             "Please quantize the original float model directly.")

    # Collector for the incompatibilities met during quantization.
    compatibility_info = ModelCompatibilityInfo(model.model)

    # Calibration ranges for the float model.
    tensors_range = calibrate(model.model, samples, num_samples)

    # Sequence-by-sequence quantization.
    qmodel = quantize_sequential(model,
                                 compatibility_info,
                                 tensors_range=tensors_range,
                                 input_dtype=input_dtype)

    # Clean up the quantizers that serve no purpose.
    remove_pointless_quantizers(qmodel)
    return qmodel, compatibility_info
