#!/usr/bin/env python
# ******************************************************************************
# Copyright 2025 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# Nothing is exported from this module through ``from ... import *``.
__all__ = []

import onnx

from .core import cnn2snn_convert_and_map
from ..tools import (ONNXExtractorModel, ensure_model_type, convert_model_to,
                     find_no_initializer_inputs, is_convertible, is_quantizable)


def _infer_partial_io(nodes, iname_to_nodes, exclude=[], graph_outputs=[]):
    exclude = exclude + [""]
    all_outputs = set(out for n in nodes for out in n.output if out not in exclude)
    all_inputs = set(inp for n in nodes for inp in n.input if inp not in exclude)

    # Compute inputs.
    input_names = [inp for inp in all_inputs if inp not in all_outputs]

    # Compute outputs.
    output_names = []
    for out in all_outputs:
        # Check all 'out' consumers are inside the subgraph or it is a graph output.
        if not all(i_n in nodes for i_n in iname_to_nodes.get(out, [])) or out in graph_outputs:
            output_names.append(out)
    return input_names, output_names


def _collect_subgraph_keys_from(model, root, iname_to_nodes=None, keys_to_exclude={}):
    # Collect the longest possible sequence of nodes in a graph starting from a node.
    # Note that nodes in 'keys_to_exclude' are not traversed.
    if iname_to_nodes is None:
        iname_to_nodes = model.input_name_to_nodes()
    queue = sum([iname_to_nodes[x] for x in root.input], [])
    subgraph_keys = set()
    while queue:
        current = queue.pop(0)
        if (ckey := current.name) in subgraph_keys | keys_to_exclude:
            continue
        subgraph_keys.add(ckey)
        for child in model.get_children(current, iname_to_nodes):
            if child.name in keys_to_exclude | subgraph_keys:
                continue
            queue.append(child)
    return subgraph_keys


@ensure_model_type
def split_model_on_sequences(model, skip_outbounds={}, compatibility_info=None):
    """Split a model into sequences of consecutive nodes based on graph structure.

    The traversal starts at every node consuming a graph input. From a starting node,
    the sequence is extended child by child until one of the following conditions is met:
      - The outbound node has multiple children (branching).
      - The outbound node is linked to a graph output.
      - The next node has multiple outputs.
      - The next node is a merge node (has multiple parents).
      - The starting node itself is a merge node.
      - The next node and the starting node do not share the same ``is_convertible`` status.

    Once a sequence is closed, the children of its last (outbound) node are queued as
    starting nodes of further sequences. Each node is queued at most once.

    Optionally, certain outbound nodes can be skipped using the `skip_outbounds` mapping.

    Args:
        model (Any): the model to split.
        skip_outbounds (dict, optional): a mapping where each key is an output tensor name
            and each value is the name of the output tensor to jump to: when the current
            outbound node produces the key, the node found for the value in
            ``model.output_name_to_node()`` becomes the new outbound node (provided it has
            a single output). Note this allows to skip a set of nodes during traversal of
            graph. The mapping is only read, never modified. Defaults to {}.
        compatibility_info (Any, optional): when provided, every sequence that fails to be
            extracted is reported through its ``_set_incompatibility`` method with the
            "Quantization" stage and the error message as reason. When None, such sequences
            are silently dropped. Defaults to None.

    Returns:
        list of onnx.ModelProto: the list of sub-models.

    Raises:
        ValueError: if a key of ``skip_outbounds`` is not consumed by any node of the
            model or is a graph output.
    """
    def _insert_node_in_queue(node):
        # Queue a node only the first time it is seen.
        if node not in visited_nodes:
            node_queue.append(node)
            visited_nodes.append(node)

    def _is_merge_node(node):
        # A node is a merge node when find_no_initializer_inputs reports more than one input.
        return len(find_no_initializer_inputs(node, model.model)) > 1

    iname_to_nodes = model.input_name_to_nodes()
    oname_to_node = model.output_name_to_node()
    graph_input_names = [x.name for x in model.input]
    graph_output_names = [x.name for x in model.output]
    # Every key to skip must be consumed by some node and must not be a graph output.
    if (incompatible_skip := [x for x in skip_outbounds
                              if (x not in iname_to_nodes or x in graph_output_names)]):
        raise ValueError(f"{incompatible_skip} was not found in model or is a graph output.")

    # Adds all nodes linked to an input into the queue.
    node_queue, visited_nodes = [], []
    for node in model.nodes():
        if any(x in graph_input_names for x in node.input):
            _insert_node_in_queue(node)

    # Iter over the queue.
    sequences = []
    while len(node_queue) > 0:
        target_node = outbound_node = node_queue.pop(0)
        # Find the longest sequence of consecutive nodes until one of the following criteria
        # * outbound_node has multiple children
        # * One of outbound_node outputs is linked to the graph output
        # * some child of outbound_node is a merge node
        # * target node is a merge node
        # * the child and target_node do not have the same 'is_convertible' status
        # Note nodes are skipped when their producer output is a key of 'skip_outbounds'.
        # The first iteration evaluates the criteria on target_node itself.
        child_nodes = [target_node]
        while (len(child_nodes) == 1 and
                not any(_is_merge_node(node) for node in child_nodes) and
                len(child_nodes[0].output) == 1 and
                outbound_node.output[0] not in graph_output_names and
                is_convertible(child_nodes[0]) == is_convertible(target_node)):
            outbound_node = child_nodes[0]
            # Update outbound with 'skip_outbounds' info: while the outbound output is a key
            # of 'skip_outbounds', jump to the single-output node producing the mapped tensor.
            while ((next_outbound := oname_to_node.get(
                    skip_outbounds.get(outbound_node.output[0], ''), False)) and
                    len(next_outbound.output) == 1):
                outbound_node = next_outbound
            # Compute new children for next iter.
            child_nodes = model.get_children(outbound_node, iname_to_nodes)

        # Extract the sequence of consecutive nodes between target_node and outbound_node.
        input_names = find_no_initializer_inputs(target_node, model.model)
        output_names = outbound_node.output
        try:
            seq_model = model.extract_model(input_names, output_names)
            seq_model = convert_model_to(seq_model, new_type=onnx.ModelProto)
            sequences.append(seq_model)
        except Exception as e:
            # Ignore sequence that are not possible to extract.
            # The failure is only recorded when a compatibility_info object was provided.
            if compatibility_info is not None:
                node_sequence = model.extractor._collect_reachable_nodes(input_names, output_names)
                compatibility_info._set_incompatibility(
                    node_sequence=node_sequence,
                    stage="Quantization",
                    faulty_node=node_sequence[0].name,
                    reason=str(e))

        # Append children into queue (they start the next sequences, even after a failure).
        for node in model.get_children(outbound_node, iname_to_nodes):
            _insert_node_in_queue(node)
    return sequences


@ensure_model_type
def search_cycles_on_model(model):
    """Identify and extract the loop structures (cycles) within an ONNX model graph.

    A cycle starts at a split node (a node with more than one child) and ends at the
    node where all the branches leaving the split converge (complete skip connection).

    For each split node, the branches are followed forward until they all reach the
    same node. When that happens, the sub-model going from the split node outputs to
    the convergence node outputs is extracted. The search for a split is abandoned
    when one of the followed nodes is linked to a graph output before convergence,
    and sub-models that fail to be extracted are discarded.

    Args:
        model (Any): The ONNX model to extract the sub-models.

    Returns:
        list of onnx.ModelProto: A list of sub-models, one per cycle found.
    """
    iname_to_nodes = model.input_name_to_nodes()
    oname_to_node = model.output_name_to_node()
    graph_output_names = [x.name for x in model.output]

    # Add all split nodes in the queue, each one paired with its children
    # (the heads of the branches to follow).
    splits = [(node, first_queue) for node in model.nodes()
              if len(first_queue := model.get_children(node, iname_to_nodes)) > 1]

    # Precompute node positions to accelerate ordering algorithm
    # (assumes model.nodes() is returned in topological order).
    node2index = {node.name: idx for idx, node in enumerate(model.nodes())}

    skip_loops = []
    for split_node, queue in splits:
        # Update the queue with node outbounds until they all converge to one.
        while not all(node == queue[0] for node in queue):
            # Search the split/merge/output outbound: move forward from the first queued
            # node while it has exactly one child and one parent, and its child is not
            # linked to a graph output.
            current_node = target_node = queue[0]
            while True:
                children = model.get_children(target_node, iname_to_nodes)
                parents = model.get_parents(target_node, oname_to_node)
                if (len(children) != 1 or len(parents) != 1 or
                        children[0].output[0] in graph_output_names):
                    break
                target_node = children[0]

            # Updated current node in the queue with the following criteria:
            # * current is different from target
            # * children when target is a split node
            # * target's outbound when it is a merge node.
            # Note we are able to finish the algorithm if target is linked to an output.
            queue[0] = target_node
            if current_node == target_node:
                # No progress was made: the node itself is a split, a merge or an end.
                if len(children) > 1:
                    # Split node case.
                    queue.pop(0)
                    queue.extend(children)
                elif len(children) == 1:
                    # Merge node case.
                    queue[0] = children[0]
                # Break if there is a node in the queue linked to an output.
                if any(node.output[0] in graph_output_names for node in queue):
                    break

            # Sort the queue to follow a topological order (for next iterations).
            # Then remove repeated nodes.
            new_queue = []
            for node in sorted(queue, key=lambda n: node2index[n.name]):
                if node not in new_queue:
                    new_queue.append(node)
            queue = new_queue

        # Store the sequence if branches converge to the same node merge.
        # (The check is needed because the loop above may also end through 'break'.)
        if all(node == queue[0] for node in queue):
            outgoing_node = queue[0]
            try:
                loop_model = model.extract_model(split_node.output, outgoing_node.output)
            except Exception:
                # Prune sequence as it is not 'coherent'.
                continue
            else:
                loop_model = convert_model_to(loop_model, new_type=onnx.ModelProto)
                skip_loops.append(loop_model)
    return skip_loops


@ensure_model_type
def search_common_sequences(model, filter_fn):
    """Extract sub-models composed of nodes accepted by ``filter_fn``.

    All nodes for which ``filter_fn`` returns ``True`` are grouped into disjoint
    connected components. Each component is converted into an ONNX sub-model whose
    inputs/outputs are inferred from boundary tensors, ensuring graph outputs remain reachable.

    Args:
        model (Any): ONNX model to scan.
        filter_fn (callable of onnx.NodeProto): predicate selecting nodes to keep.

    Returns:
        list: extracted sub-models for every matching component, sorted by the position
            of their first node in ``model.nodes()``.

    Note:
        Every extracted sub-model is validated with ``onnx.checker.check_model``;
        errors raised by the checker are not caught here.
    """
    # Read graph info.
    nodes = list(model.nodes())
    iname_to_nodes = model.input_name_to_nodes()
    graph_output_names = {out.name for out in model.output}
    # Besides initializers, ignore tensors that are produced but neither consumed by
    # a node nor exposed as graph output. The set of used tensors is built once
    # instead of once per produced tensor.
    used_tensors = set(iname_to_nodes) | graph_output_names
    tensor_to_exclude = list(model.get_initializer_name_set())
    tensor_to_exclude.extend(out for n in nodes for out in n.output if out not in used_tensors)

    # Collect sequences matching filter function.
    key_to_node = {node.name: node for node in nodes}
    key_to_index = {key: idx for idx, key in enumerate(key_to_node)}
    # Rejected nodes are marked as visited up front so they are never traversed.
    visited = {key for key, node in key_to_node.items() if not filter_fn(node)}
    sequences = []
    for key, node in key_to_node.items():
        if key in visited:
            continue
        subgraph_keys = _collect_subgraph_keys_from(model, node, iname_to_nodes, visited)
        if not subgraph_keys:
            continue

        # Deduce input and output graph names from collected nodes.
        # Nodes are kept in model order: iterating the set directly would hand them
        # over in an arbitrary, run-dependent order.
        ordered_keys = sorted(subgraph_keys, key=key_to_index.__getitem__)
        sub_nodes = [key_to_node[k] for k in ordered_keys]
        input_names, output_names = _infer_partial_io(sub_nodes,
                                                      iname_to_nodes,
                                                      exclude=tensor_to_exclude,
                                                      graph_outputs=graph_output_names)

        # Extract sub-model and append to list.
        # Note the model is extracted through the extractor directly (not through
        # model.extract_model) and validated explicitly once the graph name is set.
        seq_model = model.extractor.extract_model(input_names, output_names)
        seq_model.graph.name = model.graph().name
        onnx.checker.check_model(seq_model, full_check=True)
        sequences.append(seq_model)

        # Update visited nodes.
        visited.update(subgraph_keys)

    # Sort sequences in topological order.
    return sorted(sequences, key=lambda seq: nodes.index(seq.graph.node[0]))


@ensure_model_type
def search_convertible_sequences(model, device=None, **device_kwargs):
    """Return the sub-models of a quantized model that are candidates for conversion.

    The model is first grouped into sequences of convertible nodes. The cycles (loops)
    found in those sequences are then filtered, merged when they are chained, and
    validated by attempting a conversion. Finally each convertible sequence is split
    according to the graph structure, keeping the validated cycles in one piece.

    Args:
        model (Any): the quantized ONNX model to analyze.
        device (akida.Device, optional): the Akida device to map the Akida sub models.
            Defaults to None.
        device_kwargs (dict, optional): parameters for computing device if device = None.
            Defaults to {}.

    Returns:
        list: a list of ONNX sub-models.
    """
    def _as_extractor_model(proto):
        # Wrap the sub-model, sharing the extractor of the main model to speed up runtime.
        wrapped = convert_model_to(proto, new_type=ONNXExtractorModel)
        wrapped.extractor = model.extractor
        return wrapped

    # Group convertible nodes. Each group is expected to have a single input and a
    # single output (quantize approach).
    convertible_sequences = search_common_sequences(model, is_convertible)

    # Gather the cycles of every convertible sequence.
    cycles = []
    for sequence in convertible_sequences:
        cycles.extend(search_cycles_on_model(_as_extractor_model(sequence)))

    # Step 1: drop cycles with full branches (non-mappable), i.e. those whose last node
    # does not consume any input of the cycle directly.
    candidate_cycles = []
    for cycle in cycles:
        merge_inputs = set(cycle.graph.node[-1].input[:])
        cycle_inputs = [x.name for x in cycle.graph.input]
        if merge_inputs.intersection(cycle_inputs):
            candidate_cycles.append(cycle)

    # Step 2: merge 'sequential' cycles (output of one feeding the next one).
    # Cycles are compared by position since they are returned in topological order.
    position = 0
    while position < len(candidate_cycles) - 1:
        head = candidate_cycles[position]
        tail = candidate_cycles[position + 1]
        if head.graph.output[0].name != tail.graph.input[0].name:
            position += 1
            continue
        try:
            fused = model.extract_model([head.graph.input[0].name],
                                        [tail.graph.output[0].name])
            fused = convert_model_to(fused, new_type=onnx.ModelProto)
        except Exception:
            position += 1
            continue
        # Replace both cycles by the fused one and examine the same position again.
        candidate_cycles.insert(position, fused)
        candidate_cycles.remove(head)
        candidate_cycles.remove(tail)

    # Step 3: keep only the cycles that are fully convertible.
    consumers_of = model.input_name_to_nodes()
    producer_of = model.output_name_to_node()
    # NOTE: converted models are collected here but are not part of the returned value.
    converted_cycles = []
    skip_outbounds = {}
    for cycle in candidate_cycles:
        try:
            # To avoid conversion issues, cycle must contain one quantized node
            # at the input and at the output. Reasons:
            # - InputData cannot be a split node and
            # - merge node develops into next quantized node
            cycle_input = cycle.graph.input[0].name
            cycle_output = cycle.graph.output[0].name
            parent_node = producer_of.get(cycle_input)
            child_nodes = consumers_of.get(cycle_output, [None])
            if not is_convertible(parent_node) or not is_convertible(child_nodes[0]):
                continue
            new_input_name = find_no_initializer_inputs(parent_node, model.model)
            new_output_name = [n.output[0] for n in child_nodes]
            if len(new_output_name) != 1:
                # We avoid converting cycles with multiple outputs.
                continue
            model_to_convert = model.extract_model(new_input_name, new_output_name)
            cnn2snn_convert_and_map(model_to_convert.model, device, **device_kwargs)
        except Exception:
            continue
        # Store the link between input_name and merge_node to compute sequences.
        converted_cycles.append(model_to_convert.model)
        skip_outbounds[cycle_input] = cycle_output

    # Step 4: split each sequence based on graph structure and skip_outbounds.
    base_sequences = []
    for sequence in convertible_sequences:
        wrapped = _as_extractor_model(sequence)
        node_outputs = [n.output[0] for n in wrapped.nodes()]
        # Only keep the links that start inside this sequence, but not at its output.
        local_skips = {}
        for source, destination in skip_outbounds.items():
            if source in node_outputs and source != wrapped.output[0].name:
                local_skips[source] = destination
        base_sequences.extend(split_model_on_sequences(wrapped, skip_outbounds=local_skips))
    return base_sequences


@ensure_model_type
def search_quantizable_sequences(model, compatibility_info=None):
    """Identify and return all quantizable sequences in a model.

    This function iteratively splits the model into sequences of consecutive nodes. Moreover,
    a sequence can contain multiple branches only if the nodes between the shared input and
    its merge node are fully quantizable.

    Sequences made only of skippable operations are discarded, and the remaining ones are
    sorted by the position of their first node in ``model.nodes()``.

    Args:
        model (Any): the model to split.
        compatibility_info (ModelCompatibilityInfo, optional): an existing ModelCompatibilityInfo
            object to accumulate incompatibility information during quantization. Every node
            that does not belong to a returned sequence is reported to it. When None, nothing
            is reported. Defaults to None.

    Returns:
        list of onnx.ModelProto: A list of sub-models.
    """
    # Search quantizable sequences in the model.
    quantizable_sequences = search_common_sequences(model, is_quantizable)

    # Split sequences with multiple inputs or input with multiple consumers in several ones.
    # Iterating backwards guarantees that deleting the current entry and appending its
    # pieces at the end never shifts the entries still to be visited.
    for idx in reversed(range(len(quantizable_sequences))):
        sequence = quantizable_sequences[idx]
        graph_input_names = {x.name for x in sequence.graph.input}
        num_input_consumers = sum(inp in graph_input_names
                                  for n in sequence.graph.node for inp in n.input)
        if len(graph_input_names) <= 1 and num_input_consumers <= 1:
            continue

        sub_model = convert_model_to(sequence, new_type=ONNXExtractorModel)
        # Transfer extractor to speed up runtime.
        sub_model.extractor = model.extractor
        # Search cycles in sequence.
        cycles = search_cycles_on_model(sub_model)
        # Split sequence, keeping cycles intact.
        skip_outbounds = {cycle.graph.input[0].name: cycle.graph.output[0].name
                          for cycle in cycles
                          if cycle.graph.input[0].name not in graph_input_names}
        split_seqs = split_model_on_sequences(sub_model, skip_outbounds=skip_outbounds)
        # Replace original sequence with split ones. The entry is deleted by position:
        # list.remove() would rely on value equality and could drop another equal entry.
        del quantizable_sequences[idx]
        quantizable_sequences.extend(split_seqs)

    # Filter sequences that have full skippable nodes.
    SKIPPABLE_OPS = {'Relu', 'activation', 'MaxPool', 'Add', 'Mul', 'Concat', 'Flatten', 'Cast',
                     'Transpose'}
    quantizable_sequences = [seq for seq in quantizable_sequences
                             if not all(n.op_type in SKIPPABLE_OPS for n in seq.graph.node)]

    # Sort sequences in topological order.
    node_list = list(model.nodes())
    quantizable_sequences = sorted(quantizable_sequences,
                                   key=lambda seq: node_list.index(seq.graph.node[0]))

    # Tag nodes not in quantizable sequences as incompatibles.
    # Skipped when no compatibility_info was given, as the argument is documented as optional.
    if compatibility_info is not None:
        quantizable_node_names = {n.name for seq in quantizable_sequences for n in seq.graph.node}
        for node in node_list:
            if node.name not in quantizable_node_names:
                compatibility_info._set_incompatibility(
                    node_sequence=[node],
                    stage="Quantization",
                    faulty_node=node.name,
                    reason="No quantizable pattern found.")
    return quantizable_sequences
