#!/usr/bin/env python
# ******************************************************************************
# Copyright 2024 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

# Public API of this module: only the ``reorder_ops`` pass is exported.
__all__ = ["reorder_ops"]

import uuid
import numpy as np
import onnx.numpy_helper
from onnx.helper import make_node
from onnxscript import rewriter
import onnx_ir as ir

from ...graph_tools import get_field
from ..model import ONNXModel
from .utils import safe_fail


def _add_rescale_node(onnx_model, input_name, output_name, scale=1.0, offset=0.0,
                      perm=[], nodes_to_remove=[]):
    """Creates a custom Rescale node as the sequence of (Cast) -> (Tranpose) -> (Add) -> (Mul)
    operations.

    Args:
        onnx_model (ONNXModel): The ONNX model to which the rescale node will be added.
        input_name (str): The name of the input tensor for the rescale operation.
        output_name (str): The name of the output tensor for the rescale operation.
        scale (float or list, optional): The scaling factors to apply to the model inputs.
            Defaults to 1.0.
        offset (float or list, optional): The offset values to apply after scaling the model
            inputs. Defaults to 0.0.
        perm (list, optional): The permutation to apply to the dimensions of the rescale
            node inputs. Defaults to [].
    """
    def _format_weight(x):
        x = np.array(x, dtype="float32")
        if np.size(x) == 1:
            x = np.squeeze(x)
        return x

    nodes, weights = [], []
    inode = input_name
    unique_id = str(uuid.uuid4())

    needs_cast = "Cast" in [node.op_type for node in nodes_to_remove]
    if needs_cast:
        nodes.append(make_node('Cast',
                               inputs=[inode],
                               outputs=[f"{input_name}/cast_{unique_id}"],
                               to=onnx.TensorProto.FLOAT))
        inode = nodes[-1].output[0]

    # Create a Transpose node if permutation changes the order of inputs
    needs_to_tranpose = any(x != idx for idx, x in enumerate(perm))
    if needs_to_tranpose:
        nodes.append(make_node('Transpose',
                               inputs=[inode],
                               outputs=[f"{input_name}/transposed_{unique_id}"],
                               perm=perm))
        inode = nodes[-1].output[0]

    # Create an Offset node if needed
    # Note we need to permute offset since transpose is applied as first operation
    needs_offset = np.any(offset != 0.0)
    if needs_offset:
        nodes.append(make_node("Add",
                               inputs=[inode, f"{input_name}/input_offset_{unique_id}"],
                               outputs=[f"{input_name}/shifted_{unique_id}"]))
        weights.append(onnx.numpy_helper.from_array(_format_weight(offset), nodes[-1].input[1]))
        inode = nodes[-1].output[0]

    # Create a Scale node
    # Note we need to permute scale since transpose is applied as first operation
    needs_scale = np.any(scale != 1.0)
    if needs_scale:
        nodes.append(make_node("Mul",
                               inputs=[inode, f"{input_name}/input_scale_{unique_id}"],
                               outputs=[f"{input_name}/scaled_{unique_id}"]))
        weights.append(onnx.numpy_helper.from_array(_format_weight(scale), nodes[-1].input[1]))

    # Replace last name if there are at least one node to append
    if len(nodes) > 0:
        nodes[-1].output[0] = output_name

    # Case when there are nodes to remove but no new nodes to add
    elif len(nodes_to_remove) > 0:
        onnx_model.replace_input_of_all_nodes(nodes_to_remove[0].output[0], input_name)

    # Add nodes to onnx/weights to model
    if weights:
        onnx_model.initializer_extend(weights)
    onnx_model.add_nodes(nodes)


@safe_fail
def reorder_ops(model):
    """Reorders model entry operations to follow the following sequence:
    (Cast) -> (Transpose) -> (Add) -> (Mul)

    Processing starts only when the (single) model input feeds exactly one node. Consecutive
    nodes are then folded into one accumulated permutation, offset and scale:

    - Cast to float,
    - Transpose,
    - Add or Mul with a constant input.

    Folding stops at the first node that cannot be folded, and ends after a node that does
    not have exactly one outbound. The folded nodes are replaced by the canonical sequence
    above. Afterwards, Add/Mul nodes followed by a Flatten are reordered where possible.

    Args:
        model (ONNXModel): The ONNX model to be processed.
    """
    def _get_variables(node):
        # Value of each node input as returned by model.get_variable, or None when
        # get_variable raises an AssertionError for that input
        values = []
        for inode in node.input:
            try:
                values.append(model.get_variable(inode))
            except AssertionError:
                values.append(None)
        return values

    assert isinstance(model, ONNXModel)
    assert len(model.input) == 1, "Only a single input is supported"

    # Start with the first node only if the input is connected to one node
    # Otherwise, create a fake node to skip main loop
    first_nodes = model.input_name_to_nodes()[model.input[0].name]
    node = first_nodes[0] if len(first_nodes) == 1 else onnx.NodeProto()

    # Default Rescale parameters. scale/offset are kept in the layout of the current tensor
    # (i.e. after all Transposes folded so far); perm maps current axes to input axes.
    rank = len(model.get_input_shape(model.input[0].name))
    scale, offset, perm = np.ones([1] * rank), np.zeros([1] * rank), list(range(rank))
    nodes_to_remove = []

    # Main loop
    while node.op_type in ("Cast", "Mul", "Add", "Transpose"):
        variables = _get_variables(node)
        if node.op_type not in {"Cast", "Transpose"} and all(x is None for x in variables):
            # Nothing to do if there is no initializer in math ops
            break

        if node.op_type == "Cast":
            # Only Cast to float is supported
            # NOTE(review): the Cast is re-emitted first, so Add/Mul folded before it will then
            # run in float instead of the original input type - confirm this is intended.
            if get_field(node, "to") != onnx.TensorProto.FLOAT:
                break

        # Apply transformation to current values
        # Note that it does not make sense for both inputs
        # to be initializers (the graph would be disconnected).
        new_value = next((x for x in variables if x is not None), None)
        if node.op_type == "Mul":
            # (x + offset) * scale * m: scale is multiplied by the Mul value
            scale = scale * new_value

        elif node.op_type == "Add":
            # (x + offset) * scale + a == (x + offset + a / scale) * scale
            offset = offset + (new_value / scale)

        elif node.op_type == "Transpose":
            next_perm = get_field(node, "perm", default=list(range(rank))[::-1])
            # Move the accumulated rescale parameters into the new layout
            scale = np.transpose(scale, next_perm)
            offset = np.transpose(offset, next_perm)
            # Compose permutations: transposing by `perm` then by `next_perm` is a single
            # transpose whose axis i is the input axis perm[next_perm[i]].
            perm = [perm[axis] for axis in next_perm]

        nodes_to_remove.append(node)

        # Break loop if current node has multiple outbounds
        outbounds = model.get_children(node)
        if len(outbounds) != 1:
            break

        # Get next node
        node = outbounds[0]

    if len(nodes_to_remove) == 0:
        # Nothing to do
        return

    # The new sequence takes over the output tensor of the last folded node
    rescale_output_name = nodes_to_remove[-1].output[0]

    # Add the rescale node as a sequence of (Cast), (Transpose) + (Add) + (Mul)
    rescale_input_name = model.input[0].name
    _add_rescale_node(model, rescale_input_name, rescale_output_name,
                      scale, offset, perm, nodes_to_remove)
    model.remove_nodes(nodes_to_remove)

    # As we add new nodes, we need to topologically sort the model graph
    model.topological_sort()
    model.clean_initializers()

    # Reorder rescale and flatten if applicable
    _reorder_rescale_flatten(model)


class _ReorderRescaleFlatten(rewriter.RewriteRuleClassBase):
    """Rewrites ``Flatten(Rescale_op(x, value))`` into ``Rescale_op(Flatten(x), value')``.

    ``Rescale_op`` is the Add or Mul operation selected by ``rescale_op_type``. ``x`` must be
    a rank-4 tensor, ``value`` a constant of rank <= 4, and the Flatten axis must be 1. A
    single-element ``value`` is kept as a scalar. Otherwise ``value`` is broadcast to the
    input non-batch dimensions and flattened to ``(B, C*H*W)``.
    """
    def __init__(self, rescale_op_type):
        super().__init__()
        self.rescale_op_type = rescale_op_type

    def pattern(self, op, x, rescale_value):
        # Rescale_op (Mul/Add) > Flatten
        rescale_op = getattr(op, self.rescale_op_type)
        return op.Flatten(rescale_op(x, rescale_value), _outputs=["flat_out"])

    def check(self, context, x, rescale_value, flat_out, **__):
        check_result = rewriter.MatchResult()
        flatten_node = flat_out.producer()

        if x.shape is None:
            return check_result.fail(f"Unknown shape for input ({x.name}).")

        if len(x.shape) != 4:
            return check_result.fail(
                f"Input ({x.name}) rank is expected to be 4. Actual: {len(x.shape)}")

        rescale_tensor = ir.convenience.get_const_tensor(rescale_value)
        if rescale_tensor is None:
            return check_result.fail(f"{rescale_value.name} is expected to be constant.")

        # Check that axis is 1 for Flatten
        if (axis := flatten_node.attributes.get("axis", ir.AttrInt64("axis", 1)).as_int()) != 1:
            return check_result.fail(f"Flatten axis should be 1. Actual : {axis}")

        rescale_array = rescale_tensor.numpy()
        # A constant of higher rank would change the rank of the Flatten input
        if rescale_array.ndim > 4:
            return check_result.fail(
                f"{rescale_value.name} rank is expected to be <= 4. Actual: {rescale_array.ndim}")

        # Non-scalar constants are materialized over C, H, W: these dims must be static
        if rescale_array.size != 1 and not all(isinstance(d, int) for d in x.shape[1:]):
            return check_result.fail(f"Input ({x.name}) non-batch dims are expected to be static.")

        return check_result

    def rewrite(self, op, x, rescale_value, **__):
        # Reorder to Flatten > Rescale_op (Mul/Add)
        rescale_op = getattr(op, self.rescale_op_type)

        rescale_value_name = rescale_value.name
        rescale_value = ir.convenience.get_const_tensor(rescale_value).numpy()

        if rescale_value.size == 1:
            # A single value broadcasts to any shape. Store it as a scalar so it cannot add
            # leading dimensions to the 2D flattened tensor (e.g. a (1, 1, 1, 1) constant).
            rescale_value = rescale_value.reshape(())
        else:
            # ONNX broadcasting is right-aligned: left-pad with 1s to get a (B, C, H, W) shape
            rescale_value = rescale_value.reshape(
                (1,) * (4 - rescale_value.ndim) + rescale_value.shape)

            # Repeat each non-batch dim of size 1 to match the corresponding input dim
            for axis, dim in enumerate(rescale_value.shape[1:], start=1):
                if dim == 1:
                    rescale_value = np.repeat(rescale_value, x.shape[axis], axis=axis)

            # Flatten all dims except first, expected shape: (B, C*H*W)
            rescale_value = rescale_value.reshape((rescale_value.shape[0], -1))

        rescale_value = op.initializer(ir.tensor(rescale_value), name=rescale_value_name)
        return rescale_op(op.Flatten(x), rescale_value)


def _reorder_rescale_flatten(model):
    """Applies the Add/Mul > Flatten reordering rules to ``model``.

    One ``_ReorderRescaleFlatten`` rule is built for Add and one for Mul (in that order).
    Both are passed to ``model.rewrite``.

    Args:
        model (ONNXModel): The ONNX model to process.
    """
    assert isinstance(model, ONNXModel)

    # The rules inspect input shapes, so shape inference has to run beforehand
    model.check_model()

    model.rewrite([_ReorderRescaleFlatten.rule(rescale_op_type=op_type)
                   for op_type in ("Add", "Mul")])
