#!/usr/bin/env python
# ******************************************************************************
# Copyright 2025 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# Public API of this module: only `convert_to_hybrid` is exported by a wildcard import.
__all__ = ["convert_to_hybrid"]

import re
import tqdm
import akida
import warnings

from .sequences import search_convertible_sequences
from .core import cnn2snn_convert_and_map
from ..hybrid_model import HybridModel
from ..compatibility_info import ModelCompatibilityInfo
from ..tools import ensure_model_type, find_no_initializer_inputs, ONNXExtractorModel


def _find_tensors_by_error_msg(error, model):
    """Locates the tensors surrounding the node named in a conversion/mapping error message.

    The error message is scanned for the names of the nodes of ``model``. A name is recognized
    when it is not glued to a preceding word character and is immediately followed by one of
    ``.``, ``'``, ``"`` or ``:`` (e.g. ``'conv_1':``). When several nodes are found, the first
    one in ``model.nodes()`` order is used.

    Args:
        error (Exception): the exception raised while converting/mapping ``model``.
        model (ONNXExtractorModel): the model that failed. Every node must have a name.

    Returns:
        tuple: ``(input_names, output_names)``. When a node is found, these are the inputs of
            that node returned by ``find_no_initializer_inputs`` and the node outputs. Otherwise
            a warning is emitted and both are the outputs of the first node, so that the caller
            splits the model right after it.
    """
    nodes = model.nodes()
    assert all(node.name != "" for node in nodes), "node names are required."
    error_msg = str(error)
    # Longest names first: regex alternation keeps the first alternative that matches, so a
    # node named "conv" would otherwise shadow "conv.1" ('.' is an accepted terminator).
    names = sorted({node.name for node in nodes}, key=lambda name: (-len(name), name))
    # Note '\b' inside a character class is a backspace, not a word boundary: boundaries are
    # expressed with look-arounds. The name must not start in the middle of a word and must be
    # closed by a dot, a quote or ':'.
    pattern = re.compile(r"(?<!\w)(" + "|".join(map(re.escape, names)) + r")(?=[.'\":])")
    # Find all node names cited in the error message.
    matched_names = set(pattern.findall(error_msg))
    target_nodes = [node for node in nodes if node.name in matched_names]
    if target_nodes:
        # Split model before and after the first target node.
        input_names = find_no_initializer_inputs(target_nodes[0], model.model)
        output_names = target_nodes[0].output[:]
    else:
        warnings.warn(f"Node was not found in error: {error_msg}.", stacklevel=2)
        # Split model after first node.
        input_names = output_names = list(nodes[0].output)
    return input_names, output_names


@ensure_model_type
def convert_sequential(qmodel, device=None, hybrid_model=None,
                       compatibility_info=None, **device_kwargs):
    """Recursively converts a quantized ONNX model into a HybridModel.

    The whole model is first handed to ``cnn2snn_convert_and_map``. When that succeeds, the
    resulting Akida model is registered in the HybridModel between the first input and the first
    output of ``qmodel``. When it fails, the model is split into three parts around the node named
    in the error message (or right after its first node when no name can be found), and each part
    is processed again by this function. A single-node model that still fails is not added as an
    Akida model; its incompatibility is recorded when ``compatibility_info`` is given.

    Args:
        qmodel (Any): the quantized ONNX model to convert.
        device (akida.Device, optional): the Akida device to map the Akida sub models.
            If not present, compute one. Defaults to None.
        hybrid_model (HybridModel, optional): an existing HybridModel to which sub-models
            will be added. If None, a new HybridModel is created from ``qmodel``.
            Defaults to None.
        compatibility_info (ModelCompatibilityInfo, optional): an existing ModelCompatibilityInfo
            object to accumulate incompatibility information during conversion and mapping.
            If None, incompatibilities are not recorded. Defaults to None.
        device_kwargs (dict, optional): extra parameters forwarded to
            ``cnn2snn_convert_and_map`` (and to the recursive calls), used to compute the
            device when ``device`` is None. Defaults to {}.

    Returns:
        HybridModel: ``hybrid_model`` (or the newly created one), with every successfully
            converted sub-model added to it.
    """
    def _record_failure(failing_model, error):
        # Incompatibilities are only tracked when a container was provided.
        if compatibility_info is None:
            return
        stage = "Mapping"
        reason = str(error)
        if "Cannot convert" in reason:
            stage = "Conversion"
            # Prefer the message of the chained cause when there is one.
            if error.__cause__ is not None:
                reason = str(error.__cause__)
        failing_nodes = failing_model.nodes()[:]
        compatibility_info._set_incompatibility(node_sequence=failing_nodes,
                                                stage=stage,
                                                faulty_node=failing_nodes[0].name,
                                                reason=reason)

    if hybrid_model is None:
        hybrid_model = HybridModel(qmodel.model)
    convert_kwargs = dict(device=device, **device_kwargs)
    recursive_kwargs = dict(compatibility_info=compatibility_info, **convert_kwargs)

    # An empty model has nothing to convert.
    if len(qmodel.nodes()) == 0:
        return hybrid_model

    try:
        akida_model = cnn2snn_convert_and_map(qmodel.model, **convert_kwargs)
    except Exception as err:
        # A single node cannot be split any further: record why it failed and stop.
        if len(qmodel.nodes()) == 1:
            _record_failure(qmodel, err)
            return hybrid_model
        # Cut around the node blamed by the error message: head, faulty part, tail.
        before, after = _find_tensors_by_error_msg(err, qmodel)
        pieces = (qmodel.extract_model(output_names=before),
                  qmodel.extract_model(before, after),
                  qmodel.extract_model(input_names=after))
        for piece in pieces:
            convert_sequential(piece, hybrid_model=hybrid_model, **recursive_kwargs)
        return hybrid_model

    # The whole model converted: register it as a single Akida sub-model.
    hybrid_model._add(akida_model, qmodel.input[0].name, qmodel.output[0].name)
    return hybrid_model


@ensure_model_type
def convert_to_hybrid(qmodel, device=None, *, enable_hwpr=False,
                      sram_size=None, minimal_memory=False):
    """Converts a quantized ONNX model into a HybridModel containing Akida and ONNX sub-models.

    This function splits the input quantized model into convertible sequences, attempts to convert
    and map each fully quantized sequence to an Akida model and adds it to the HybridModel.
    Sequences that cannot be converted and mapped to Akida are added on ONNX domain. The function
    also manages inbound connections between sub-models to preserve the original model's topology.

    When device is not provided, a minimal device is computed independently for each Akida
    compatible model, then a common device is computed and used to map the HybridModel.

    Args:
        qmodel (Any): the quantized ONNX model to convert.
        device (akida.Device, optional): the Akida device to map the Akida sub models.
            Defaults to None.
        enable_hwpr (bool, optional): if True, the device is computed assuming partial
            reconfiguration. Used when `device` is None. Defaults to False.
        sram_size (akida.NP.SramSize, optional): Size of shared SRAM available inside the mesh.
            Ignored when `minimal_memory` is True. Used when `device` is None. Defaults to None.
        minimal_memory (bool, optional): if True, computes and sets the minimal required
            inputs and weights memory for the device. Used when `device` is None.
            Defaults to False.

    Returns:
        tuple(HybridModel, ModelCompatibilityInfo): the hybrid model containing Akida and ONNX
        sub-models, and the incompatibilities recorded while converting and mapping the
        sequences.
    """
    if device is not None and (enable_hwpr or sram_size is not None or minimal_memory):
        warnings.warn(
            "When 'device' is provided, the parameters 'enable_hwpr', 'sram_size', "
            "and 'minimal_memory' are ignored. Continuing execution."
        )
    device_params = dict(device=device,
                         enable_hwpr=enable_hwpr,
                         sram_size=sram_size,
                         minimal_memory=minimal_memory)

    # Search convertible sequences.
    sequences = search_convertible_sequences(qmodel, **device_params)

    # Convert each sequence (recursively split on failure) and register the Akida parts in the
    # HybridModel; incompatibilities are recorded in compatibility_info.
    model = HybridModel(qmodel.model)
    compatibility_info = ModelCompatibilityInfo(qmodel.model)
    for sequence in tqdm.tqdm(sequences, desc="Converting"):
        sub_model = ONNXExtractorModel(sequence)
        # Reuse the parent extractor instead of building one per sequence (speeds up runtime).
        sub_model.extractor = qmodel.extractor
        convert_sequential(sub_model,
                           hybrid_model=model,
                           compatibility_info=compatibility_info,
                           **device_params)

    # Without a user device, each Akida sub-model was mapped on its own minimal device: remap
    # all of them on a common device computed from every Akida sub-model.
    if device is None and len(model.akida_models) != 0:
        device = akida.compute_common_device(model.akida_models, enable_hwpr=enable_hwpr)
        model.map(device, mode=akida.MapMode.Minimal)

    return model, compatibility_info
