# Extension module for FindCalligraphers main program

from collections import Counter
from pathlib import Path

import cv2 as cv
import numpy as np
import torch
from numpy import ndarray

from model import CalligraphyClassifyModel


# Saved state_dict for CalligraphyClassifyModel. The path is relative, so it
# is resolved against the current working directory at load time.
MODEL_PT_FILE = Path("model.pt")

# Run on the current CUDA device when PyTorch can see one; otherwise use the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def load_calligraphy_model():
    """Build a CalligraphyClassifyModel loaded with the weights in MODEL_PT_FILE.

    Returns:
        The model, moved to DEVICE and switched to evaluation mode.
    """
    model = CalligraphyClassifyModel()
    # map_location remaps the stored tensors onto DEVICE. Without it, torch.load
    # restores them to the device they were saved from, which fails on a
    # CPU-only machine when the checkpoint was written from a GPU.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints
    # (on PyTorch >= 1.13 consider passing weights_only=True for a state_dict).
    state_dict = torch.load(MODEL_PT_FILE, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    return model


def _boxes_overlap(a: tuple, b: tuple, overlap_ratio: float) -> bool:
    """Return True if (x, y, w, h) boxes a and b intersect by more than
    overlap_ratio of the smaller box's size on BOTH axes."""
    ax, ay, aw, ah = a
    bx, by, bw, bh = b
    ix1 = max(ax, bx)
    iy1 = max(ay, by)
    ix2 = min(ax + aw, bx + bw)
    iy2 = min(ay + ah, by + bh)
    return (
        ix1 < ix2
        and iy1 < iy2
        and (ix2 - ix1) > overlap_ratio * min(aw, bw)
        and (iy2 - iy1) > overlap_ratio * min(ah, bh)
    )


def _union_box(a: tuple, b: tuple) -> tuple:
    """Return the smallest (x, y, w, h) box that encloses both a and b."""
    ax, ay, aw, ah = a
    bx, by, bw, bh = b
    x1 = min(ax, bx)
    y1 = min(ay, by)
    x2 = max(ax + aw, bx + bw)
    y2 = max(ay + ah, by + bh)
    return x1, y1, x2 - x1, y2 - y1


def find_contour_boxes(
    contours: list[ndarray],
    image_width: int,
    image_height: int,
    *,
    min_ratio: float = 0.05,
    max_ratio: float = 0.5,
    overlap_ratio: float = 0.3,
):
    """Turn contours into filtered, merged (x, y, w, h) bounding boxes.

    Args:
        contours: Contours as returned by ``cv.findContours``.
        image_width: Width of the image the contours came from, in pixels.
        image_height: Height of that image, in pixels.
        min_ratio: A box narrower than ``min_ratio * image_width`` or shorter
            than ``min_ratio * image_height`` is discarded.
        max_ratio: A box wider than ``max_ratio * image_width`` or taller
            than ``max_ratio * image_height`` is discarded.
        overlap_ratio: A new box is merged into an existing one when their
            intersection exceeds this fraction of the smaller box's width
            AND height.

    Returns:
        List of ``(x, y, w, h)`` tuples in first-seen order.
    """
    boxes = []
    for contour in contours:
        x, y, w, h = cv.boundingRect(contour)

        # Skip rectangles outside the allowed fraction of the image size.
        if (
            w < min_ratio * image_width
            or h < min_ratio * image_height
            or w > max_ratio * image_width
            or h > max_ratio * image_height
        ):
            continue

        box = (x, y, w, h)
        # Merge into the first sufficiently overlapping box. The enlarged box
        # is not re-checked against the other boxes, so overlaps created by
        # the merge itself are left as they are.
        for idx, existing in enumerate(boxes):
            if _boxes_overlap(box, existing, overlap_ratio):
                boxes[idx] = _union_box(box, existing)
                break
        else:
            boxes.append(box)

    return boxes


def extract_characters(image_input: ndarray):
    """Locate candidate character regions in a BGR image and draw them.

    The image is converted to grayscale, binarized at threshold 127 and
    morphologically closed with a 15x15 kernel, which fills dark gaps narrower
    than the kernel so nearby white regions fuse together. (This presumably
    expects light characters on a dark background; confirm against inputs.)
    Contour bounding boxes are then filtered and merged by find_contour_boxes.

    Args:
        image_input: 3-channel BGR image; it is not modified.

    Returns:
        ``(image_with_boxes, boxes)``: a copy of ``image_input`` with each box
        drawn as a green rectangle, and the list of ``(x, y, w, h)`` boxes.
    """
    gray = cv.cvtColor(image_input, cv.COLOR_BGR2GRAY)
    _, binary = cv.threshold(gray, 127, 255, cv.THRESH_BINARY)
    closed = cv.morphologyEx(binary, cv.MORPH_CLOSE, np.ones((15, 15), np.uint8))

    # findContours returns (contours, hierarchy) in OpenCV 4.x but
    # (image, contours, hierarchy) in OpenCV 3.x; [-2] selects the contours
    # in both.
    contours = cv.findContours(closed, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)[-2]
    height, width = closed.shape[:2]
    boxes = find_contour_boxes(contours, width, height)

    image_with_boxes = image_input.copy()
    for box in boxes:
        # thickness 5, lineType 8 (8-connected line)
        cv.rectangle(image_with_boxes, box, (0, 255, 0), 5, 8)
    return image_with_boxes, boxes


def classify_calligraphers(image_input: ndarray, boxes: list[tuple]):
    """Classify each boxed character and rank the predicted labels by votes.

    The model is loaded lazily on first use and cached on the function object
    as ``classify_calligraphers.model``.

    Args:
        image_input: 3-channel BGR image the boxes refer to.
        boxes: ``(x, y, w, h)`` rectangles, one per character crop.

    Returns:
        Distinct predicted labels (as ``int``), most frequent first. Labels
        with equal counts keep the order in which they were first predicted.
        Returns an empty list when ``boxes`` is empty (the model is not loaded).
    """
    if not boxes:
        return []

    if not hasattr(classify_calligraphers, "model"):
        classify_calligraphers.model = load_calligraphy_model()
    model = classify_calligraphers.model

    image = cv.cvtColor(image_input, cv.COLOR_BGR2GRAY)

    # Build the batch: one grayscale crop per box, resized to model.INPUT_SIZE.
    # NOTE(review): cv.resize takes dsize as (width, height), while the tensor
    # below uses INPUT_SIZE as its trailing dims in (rows, cols) order. The two
    # only agree when INPUT_SIZE is square; confirm.
    model_input = torch.zeros(len(boxes), *model.INPUT_SIZE, dtype=torch.float)
    for idx, (x, y, w, h) in enumerate(boxes):
        roi = cv.resize(image[y : y + h, x : x + w], model.INPUT_SIZE)
        model_input[idx] = torch.tensor(roi, dtype=torch.float)

    # Inference only: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        model_output = model(model_input.to(DEVICE))
    labels = model.get_labels(model_output)

    # most_common() sorts by count descending; ties keep first-seen order.
    votes = Counter(map(int, labels))
    return [label for label, _ in votes.most_common()]


def find_calligraphers(input_image: ndarray):
    """Detect character regions in input_image and rank likely calligraphers.

    Args:
        input_image: 3-channel BGR image; it is not modified.

    Returns:
        ``(image_with_boxes, calligraphers)``: a copy of ``input_image`` with
        the detected boxes drawn in green, and the predicted labels ordered
        from most to least frequent.
    """
    image_with_boxes, boxes = extract_characters(input_image)
    # Classify on the untouched input. image_with_boxes has the green
    # rectangles drawn along every box edge, and those lines would otherwise
    # end up inside each crop fed to the model.
    calligraphers = classify_calligraphers(input_image, boxes)
    return image_with_boxes, calligraphers
