#!/usr/bin/env python
# ******************************************************************************
# Copyright 2025 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

__all__ = ["convert_reshape_to_transpose"]

from onnxscript.rewriter import RewriteRuleClassBase
from onnxscript.rewriter._basics import MatchResult
import onnx_ir as ir

from ..model import ONNXModel
from .utils import safe_fail


class _ReshapeToTransposeRule(RewriteRuleClassBase):
    """Rewrite rule that replaces a layout-only Reshape with an equivalent Transpose.

    A single case is recognized: a rank-4 input whose last dimension equals 1, reshaped
    to the shape obtained by moving that last axis to position 1 (NHWC -> NCHW). The
    moved axis has size 1, so the element order is unchanged and
    ``Transpose(perm=[0, 3, 1, 2])`` yields the same tensor as the Reshape.
    """

    def pattern(self, op, x):
        # Any Reshape fed by ``x``; the target-shape input is allowed but not captured.
        return op.Reshape(x, _allow_other_inputs=True)

    def check(self, context, **_):
        """Accept the match only when the Reshape is the NHWC(C=1) -> NCHW case."""
        result = MatchResult()
        node = context.nodes[0]

        src_shape = node.inputs[0].shape
        if src_shape is None:
            return result.fail("Unknown shape for Reshape input.")

        dst_shape = node.outputs[0].shape
        if dst_shape is None:
            return result.fail("Unknown shape for Reshape output.")

        mismatch = "Reshape does not correspond to a simple transpose."

        # Rank is checked first so that src_shape[3] is never indexed on a shorter shape.
        if len(src_shape) != 4:
            return result.fail(mismatch)
        if src_shape[3] != 1:
            return result.fail(mismatch)

        # The output must be the input with its (size-1) last axis moved to position 1.
        nchw_shape = ir.Shape([src_shape[0], src_shape[3], *src_shape[1:3]])
        if dst_shape != nchw_shape:
            return result.fail(mismatch)

        return result

    def rewrite(self, op, x, **_):
        # Axis order NHWC -> NCHW; the Reshape's shape input is intentionally dropped.
        nhwc_to_nchw = [0, 3, 1, 2]
        return op.Transpose(x, perm=nhwc_to_nchw)


@safe_fail
def convert_reshape_to_transpose(model):
    """
    Convert Reshape nodes that only permute the data layout into Transpose nodes.

    Only the pattern accepted by ``_ReshapeToTransposeRule`` is converted: a 4D input
    whose last dimension is 1, reshaped to (N, 1, H, W), which becomes
    ``Transpose(perm=[0, 3, 1, 2])``.

    Args:
        model (ONNXModel): The input ONNX model.

    Returns:
        ONNXModel: The transformed ONNX model with applicable Reshape nodes converted to
            Transpose nodes.
    """
    assert isinstance(model, ONNXModel)
    rewritten = model.rewrite([_ReshapeToTransposeRule().rule()])
    # NOTE(review): ONNXModel.rewrite is defined elsewhere; whether it mutates ``model``
    # in place (returning None) or returns a new model is not visible here, so both are
    # handled to honor the documented return value.
    return model if rewritten is None else rewritten
