#!/usr/bin/env python3

from __future__ import annotations

from typing import List
from typing import Tuple

import subprocess
import sys

# Tags accepted as the prefix of a commit subject, i.e. the text before the
# first ": ". enforce_message_rules() rejects any subject whose tag is not
# listed here; "release" additionally gets its own format check there.
ALLOWED_TAGS = [
    "ci",
    "cfix",
    "new",
    "cq",
    "feat",
    "fix",
    "perf",
    "refactor",
    "chore",
    "change",
    "release",
    "other",
    "imprv",
    "revert",
]


def parse_args() -> Tuple[str, str, str]:
    """Read the three hook arguments from ``sys.argv``.

    Returns:
        ``(ref_name, old_ref, new_ref)`` taken from ``sys.argv[1:4]``. Any
        further arguments are ignored.

    Raises:
        SystemExit: With a usage message (non-zero exit status) when fewer
            than three arguments were supplied.
    """
    if len(sys.argv) < 4:
        # Fail with a readable usage line instead of a bare IndexError.
        prog = sys.argv[0] if sys.argv else "update"
        sys.exit(f"usage: {prog} <ref-name> <old-ref> <new-ref>")

    ref_name, old_ref, new_ref = sys.argv[1:4]
    return ref_name, old_ref, new_ref


def get_commit_subject(sha: str) -> str:
    """Return the subject (``%s`` pretty-format field) of commit *sha*.

    Runs ``git log -1`` in the current working directory and decodes its
    output with the default codec.

    Raises:
        subprocess.CalledProcessError: If git exits with a non-zero status.
    """
    command = ["git", "log", "-1", "--pretty=format:%s", sha]
    return subprocess.check_output(command).decode()


def get_commit_shas(start_ref: str, end_ref: str) -> List[str]:
    """List the commits reachable from *end_ref* but not from *start_ref*.

    Args:
        start_ref: Old object name of the ref. An all-zero name (what git
            passes to hooks when the ref is being created) is treated as
            "no previous commits", so every commit reachable from
            *end_ref* is returned.
        end_ref: New object name of the ref. An all-zero name (ref
            deletion) is not special-cased; git rejects it and the error
            propagates.

    Returns:
        The SHAs printed by ``git rev-list``, in its output order. Empty
        when the range contains no commits.

    Raises:
        subprocess.CalledProcessError: If ``git rev-list`` exits non-zero.
    """
    # "000…0..<new>" is not a valid revision range, so a newly created ref
    # has to be listed by its tip alone.
    ref_is_new = bool(start_ref) and not start_ref.strip("0")
    revision = end_ref if ref_is_new else f"{start_ref}..{end_ref}"

    output = subprocess.check_output(["git", "rev-list", revision])
    # Empty output can happen if commits are deleted after a force push;
    # splitlines() then yields an empty list.
    return output.decode().strip().splitlines()


def enforce_message_rules(subject: str) -> None:
    """Validate a commit subject of the form ``<tag>: <description>``.

    Rules, checked in order:

    * the subject must contain ``": "``; the text before its first
      occurrence is the tag;
    * the tag must be listed in ``ALLOWED_TAGS``;
    * for the ``release`` tag the description must start with a digit;
    * for every other tag the description must start with an uppercase
      letter.

    An empty description (e.g. ``"fix: "``) fails the last two rules
    instead of raising ``IndexError``.

    Args:
        subject: The commit subject line to check.

    Raises:
        SystemExit: With exit status 1, after printing the reason, when
            any rule is violated.
    """
    tag, separator, description = subject.partition(": ")
    if not separator:
        print("Unknown commit message tag:", subject)
        sys.exit(1)

    if tag not in ALLOWED_TAGS:
        print("Unknown commit message tag:", tag)
        sys.exit(1)

    # Slicing (rather than indexing) yields "" for an empty description;
    # "".isdigit() and "".isupper() are both False, so it is rejected below.
    first_char = description[:1]

    if tag == "release":
        if not first_char.isdigit():
            print('Release commits should be of the form "release: X.X.X"')
            sys.exit(1)
        return

    if not first_char.isupper():
        print("First letter after tag must be uppercase")
        sys.exit(1)


def main(args: Tuple[str, str, str]) -> None:
    """Check the subject of every commit in a ref update against the rules.

    Args:
        args: ``(ref_name, old_ref, new_ref)`` describing the update.

    The update is always logged. Commits are inspected only when
    ``ref_name`` is ``refs/heads/master``; the first subject that breaks a
    rule terminates the process via ``enforce_message_rules``.
    """
    branch, before, after = args
    print("UPDATE", branch, before, after)

    # Only the master branch is subject to enforcement.
    if branch == "refs/heads/master":
        for commit in get_commit_shas(before, after):
            print("Check", commit)
            enforce_message_rules(get_commit_subject(commit))


if __name__ == "__main__":
    # Script entry point: hand the command-line arguments straight to main().
    main(parse_args())
