import torch


class CalligraphyClassifyModel(torch.nn.Module):
    """
    CNN model to classify calligraphy images.

    The network is two ``Conv2d -> ReLU -> MaxPool2d`` stages followed by a
    single linear layer that produces one raw score per entry of
    ``LABEL_SET``. No softmax is applied to the output.
    """

    # Label strings, one per output score.
    # NOTE(review): presumably index i of the model output corresponds to
    # LABEL_SET[i] -- confirm against the training code.
    LABEL_SET = [
        "wxz",
        "yzq",
        "lgq",
        "sgt",
        "smh",
        "mf",
        "htj",
        "oyx",
        "zmf",
        "csl",
        "wzm",
        "lqs",
        "yyr",
        "hy",
        "bdsr",
        "fwq",
        "gj",
        "shz",
        "mzd",
        "lx",
    ]
    # (height, width) of the images the linear layer is sized for.
    INPUT_SIZE = (64, 64)
    # Derived from LABEL_SET so the two can never get out of sync.
    OUTPUT_SIZE = len(LABEL_SET)

    def __init__(self) -> None:
        """Build the layers; the linear layer is sized from ``INPUT_SIZE``."""
        super().__init__()

        conv1_channels = 16
        conv2_channels = 32

        self.conv1 = torch.nn.Conv2d(
            1, conv1_channels, kernel_size=3, stride=1, padding=1
        )
        self.relu1 = torch.nn.ReLU()
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = torch.nn.Conv2d(
            conv1_channels, conv2_channels, kernel_size=3, stride=1, padding=1
        )
        self.relu2 = torch.nn.ReLU()
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2, stride=2)

        # The 3x3 / padding=1 / stride=1 convolutions keep the spatial size;
        # each 2x2 / stride=2 pool halves it (rounding down). Two pools
        # therefore shrink each spatial dimension by a factor of four.
        height, width = self.INPUT_SIZE
        pooled_height = height // 2 // 2
        pooled_width = width // 2 // 2
        flat_features = conv2_channels * pooled_height * pooled_width

        self.fc = torch.nn.Linear(flat_features, self.OUTPUT_SIZE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the per-class scores for a batch of images.

        Args:
            x: Single-channel images shaped ``(batch, H, W)``. A tensor that
                already carries the channel axis, ``(batch, 1, H, W)``, is
                accepted as well. ``(H, W)`` must equal ``INPUT_SIZE`` because
                the linear layer is sized for it.

        Returns:
            Tensor of shape ``(batch, OUTPUT_SIZE)`` holding raw scores.
        """
        # Conv2d needs an explicit channel axis; add it only when missing so
        # that (batch, 1, H, W) input is not turned into a 5-D tensor.
        if x.dim() == 3:
            x = x.unsqueeze(1)

        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))

        # flatten() copies when needed, so it does not depend on the memory
        # layout of the pooled feature map the way view() does.
        return self.fc(torch.flatten(x, start_dim=1))

    def get_labels(self, model_output: torch.Tensor) -> torch.Tensor:
        """
        Return, for each row, the index of its highest score.

        Args:
            model_output: Scores shaped ``(batch, num_classes)``, e.g. the
                result of ``forward``.

        Returns:
            Integer tensor of shape ``(batch,)`` with the arg-max along
            dimension 1. These are class indices, not the strings stored in
            ``LABEL_SET``.
        """
        return torch.max(model_output, 1).indices
