#!/usr/bin/env python
# ******************************************************************************
# Copyright 2025 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import argparse
import onnx

from .convert import convert, print_report


def get_args(argv=None):
    """Parse the command-line arguments of the model-checking tool.

    Args:
        argv (list of str, optional): arguments to parse. When None (the
            default), argparse falls back to ``sys.argv[1:]``, which preserves
            the original command-line behavior.

    Returns:
        argparse.Namespace: parsed arguments with attributes ``model`` (str),
        ``input_dtype`` (str), ``input_shape`` (tuple of int or None) and
        ``save_model`` (str or None).
    """

    def _positive_shape(value):
        """Parse a comma separated list of strictly positive integers into a tuple."""
        try:
            shape = tuple(int(dim) for dim in value.split(','))
        except ValueError:
            # Raising ArgumentTypeError lets argparse print this message
            # instead of the opaque "invalid <lambda> value".
            raise argparse.ArgumentTypeError(
                f"invalid shape {value!r}: expected a comma separated list of "
                "integers, e.g. 3,256,256") from None
        # Enforce the "> 0" contract stated in the option's help text.
        if any(dim <= 0 for dim in shape):
            raise argparse.ArgumentTypeError(
                f"invalid shape {value!r}: all values must be integers > 0")
        return shape

    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, required=True, help="Model to check")
    parser.add_argument("-id", "--input_dtype", type=str, default="uint8",
                        help="Numpy-like dtype format to quantize the inputs")
    parser.add_argument("--input_shape",
                        type=_positive_shape,
                        default=None,
                        help="Shape to use for input_shape (Excluding batch dimension). "
                        "Provide comma separated list for the shape. All values must be "
                        "integers > 0. e.g. --input_shape 3,256,256.")
    parser.add_argument("-s", "--save_model", type=str,
                        help="Save model to draw in Netron (not inference)")
    return parser.parse_args(argv)


def main():
    """Run the model-checking command-line tool.

    Loads the ONNX file named on the command line and hands it to
    ``convert`` together with the requested input shape and dtype. Then
    prints the compatibility report. When ``--save_model`` is given, the
    tagged model is also written to that path.
    """
    cli_args = get_args()

    onnx_model = onnx.load(cli_args.model)

    # convert() returns the HybridModel together with its compatibility info.
    hybrid, compat_info = convert(onnx_model,
                                  input_shape=cli_args.input_shape,
                                  input_dtype=cli_args.input_dtype)
    print_report(hybrid, compat_info)

    output_path = cli_args.save_model
    if not output_path:
        # Nothing else to do unless the user asked for the tagged model.
        return
    compat_info.save_tagged_model(output_path)
    print(f"[INFO]: Save modified graph in {output_path}.")
