#!/usr/bin/env python
# ******************************************************************************
# Copyright 2025 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# Empty export list: `from <this module> import *` brings in no names.
__all__ = []

import io
import warnings
from contextlib import redirect_stdout

import akida
import cnn2snn
from quantizeml.models import QuantizationParams, quantization
from quantizeml.onnx_support.quantization.quantize import quantize_calibrated as qml_quantize

from ..tools import ONNXExtractorModel, ensure_model_type, convert_model_to, rename_tensors

# Message patterns of the warnings that `quantize_calibrated` escalates to errors when a
# fully quantized model is required. They are handed to `warnings.filterwarnings`, which
# matches them against the beginning of the warning message.
_WARNINGS_TO_BE_HANDLE_AS_ERRORS = [
    "The following nodes were not quantized",
    "Impossible to quantize",
]


@ensure_model_type
def quantize_calibrated(model, tensors_range, input_dtype="int8", ensure_fully_quantized=True):
    """Quantizes an already sanitized model from precomputed calibration ranges.

    This mirrors `quantizeml.models.quantize()` but skips its sanitize and calibrate steps.
    Anything printed to standard output during quantization is discarded. Once quantized,
    the outputs of every node whose op type contains "Dequantizer" are renamed after the
    node inputs, which in turn receive a "/to_dequantize" suffix.

    Args:
        model (Any): the model to quantize. An already sanitized model is expected.
        tensors_range (TensorsData): calibration ranges, typically obtained from a calibration step.
        input_dtype (str, optional): the data type used to quantize the inputs. Defaults to "int8".
        ensure_fully_quantized (bool, optional): If True, the warnings listed in
            `_WARNINGS_TO_BE_HANDLE_AS_ERRORS` are raised as exceptions and any other warning
            is recorded instead of displayed. Otherwise, partial quantization is allowed and
            warnings are left untouched. Defaults to True.

    Returns:
        Any: the quantized model.

    Raises:
        Warning: when ``ensure_fully_quantized`` is True and a warning whose message starts
            with one of the patterns in `_WARNINGS_TO_BE_HANDLE_AS_ERRORS` is emitted (the
            warning itself is raised as the exception).
    """
    # Work on a copy: insert_rescaling is performed inplace during quantization.
    working_copy = model.clone()
    # Only escalate the "not fully quantized" warnings when the caller asked for it.
    fatal_patterns = _WARNINGS_TO_BE_HANDLE_AS_ERRORS if ensure_fully_quantized else ()
    with redirect_stdout(io.StringIO()):
        with warnings.catch_warnings(record=ensure_fully_quantized):
            for pattern in fatal_patterns:
                warnings.filterwarnings('error', message=pattern)
            # Quantize using the provided ranges and the requested input dtype.
            with quantization(QuantizationParams(input_dtype=input_dtype)):
                qmodel = qml_quantize(working_copy, tensors_range=tensors_range)
            # Give dequantizer outputs the original tensor names
            # (required when merging float graphs with quantized ones).
            renames = {}
            dequantizers = (node for node in qmodel.nodes() if "Dequantizer" in node.op_type)
            for dequantizer in dequantizers:
                io_pairs = zip(qmodel.get_node_inputs(dequantizer), dequantizer.output)
                for source, target in io_pairs:
                    # The input name is freed first so the output can take it over,
                    # which allows naming the outputs as graph outputs.
                    renames[source] = f"{source}/to_dequantize"
                    renames[target] = source
            rename_tensors(qmodel, renames, inplace=True)
    return convert_model_to(qmodel, new_type=ONNXExtractorModel)


def cnn2snn_convert_and_map(model, device=None, **device_kwargs):
    """Generates an Akida model based on an ONNX quantizeml model and maps it on a device.

    The model is converted with `cnn2snn.convert` (user warnings are silenced during the
    conversion), then mapped in minimal mode. Mapping is first attempted allowing software
    fallback so that the offending sequence can be located when hardware mapping fails.

    Args:
        model (obj:`onnx.ModelProto`): a ONNX model to convert.
        device (akida.Device, optional): the Akida device to map the Akida sub models.
            Defaults to None.
        device_kwargs (dict, optional): parameters for computing device if device = None.
            Defaults to {}.

    Returns:
        akida.Model: the generated Akida model, mapped on the device.

    Raises:
        RuntimeError: if at least one sequence was not mapped on the hardware backend. When
            the failing sequence is the first one of the model, the hardware-only mapping is
            retried first so that the error raised by akida (if any) is propagated instead.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        akida_model = cnn2snn.convert(model)

    # Compute device.
    map_on_device = akida.compute_min_device(
        akida_model, **device_kwargs) if device is None else device

    # Resolve the mapping mode once so that both mapping attempts below use the very same
    # enum path (the original retry spelled it `akida.MapMode`, unlike the main call).
    map_mode = akida.mapping.MapMode.Minimal

    # Map model, tolerating software sequences so they can be inspected afterwards.
    akida_model.map(map_on_device, mode=map_mode, hw_only=False)

    # Check all sequences were mapped in hardware.
    if sw := [seq for seq in akida_model.sequences if seq.backend != akida.BackendType.Hardware]:
        # Throws error on the first SW sequence in a different format:
        # 1. If it is the first sequence, raise native error (the hardware-only mapping is
        #    expected to fail with akida's own message).
        if (first_sequence := sw[0]) == akida_model.sequences[0]:
            akida_model.map(map_on_device, mode=map_mode, hw_only=True)
        # 2. Show sequence location on the model, allowing to find it by the search node algorithm.
        passes = first_sequence.passes
        raise RuntimeError(f"{passes[0].layers[0].name} -> {passes[-1].layers[-1].name}: "
                           "could not find a compatible sequence.")
    return akida_model
