#! /usr/bin/env python3
# pyright: strict

# --------------------------------------------------------------------
# --- Cachegrind's differencer.                         cg_diff.in ---
# --------------------------------------------------------------------

#  This file is part of Cachegrind, a Valgrind tool for cache
#  profiling programs.
#
#  Copyright (C) 2002-2023 Nicholas Nethercote
#     njn@valgrind.org
#
#  This program is free software; you can redistribute it and/or
#  modify it under the terms of the GNU General Public License as
#  published by the Free Software Foundation; either version 2 of the
#  License, or (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful, but
#  WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
#  General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, see <http://www.gnu.org/licenses/>.
#
#  The GNU General Public License is contained in the file COPYING.

"""
This script diffs Cachegrind output files.
"""

# Use `make pydiff` to "build" this script every time it is changed. This runs
# the formatters, type-checkers, and linters on `cg_diff.in` and then generates
# `cg_diff`.
#
# This is a cut-down version of `cg_annotate.in`.

from __future__ import annotations

import re
import sys
from argparse import ArgumentParser, Namespace
from collections import defaultdict
from typing import Callable, DefaultDict, NewType, NoReturn

# A function that applies a search-and-replace to a string and returns the
# result. This is the type of the `--mod-filename` and `--mod-funcname` values.
SearchAndReplace = Callable[[str], str]


class Args(Namespace):
    """
    A typed wrapper for parsed args.

    None of these fields are modified after arg parsing finishes.
    """

    mod_filename: SearchAndReplace
    mod_funcname: SearchAndReplace
    # The positional args are declared with `nargs=1`, so argparse stores each
    # of them as a one-element list rather than as a bare string.
    cgout_filename1: list[str]
    cgout_filename2: list[str]

    @staticmethod
    def parse() -> Args:
        """
        Parse the command line into an `Args`.

        argparse reports invalid arguments (including an invalid
        `--mod-filename`/`--mod-funcname` expression) as a usage error and
        exits.
        """

        # We support Perl-style `s/old/new/flags` search-and-replace
        # expressions, because that's how this option was implemented in the
        # old Perl version of `cg_diff`. This requires conversion from
        # `s/old/new/` style to `re.sub`. The conversion isn't a perfect
        # emulation of Perl regexps (e.g. Python uses `\1` rather than `$1` for
        # using captures in the `new` part), but it should be close enough. The
        # only supported flags are `g` (global) and `i` (ignore case).
        def search_and_replace(regex: str | None) -> SearchAndReplace:
            """
            Convert an `s/old/new/flags` expression into a function.

            `None` gives the identity function. Raises `ValueError` (which
            argparse turns into a usage error) if the expression doesn't have
            the `s/old/new/flags` shape, or if `old` or `new` is rejected by
            the `re` module.
            """
            if regex is None:
                return lambda s: s

            # Extract the parts of a `s/old/new/tail` regex. `(?<!\\)/` is an
            # example of negative lookbehind. It means "match a forward slash
            # unless preceded by a backslash".
            m = re.match(r"s/(.*)(?<!\\)/(.*)(?<!\\)/(g|i|gi|ig|)$", regex)
            if m is None:
                raise ValueError

            # Forward slashes must be escaped in an `s/old/new/` expression,
            # but we then must unescape them before using them with `re.sub`
            pat = m.group(1).replace(r"\/", r"/")
            repl = m.group(2).replace(r"\/", r"/")
            tail = m.group(3)

            count = 0 if "g" in tail else 1  # 0 means unlimited
            flags = re.IGNORECASE if "i" in tail else re.RegexFlag(0)

            # Compile once, up front, rather than on every call. This also
            # means a bad pattern or replacement template is rejected during
            # arg parsing instead of surfacing as a traceback while a cgout
            # file is being read.
            try:
                compiled = re.compile(pat, flags=flags)
                # The result is discarded; this call is only here to make
                # `re` check the replacement template.
                compiled.sub(repl, "", count=count)
            except (re.error, IndexError) as err:
                raise ValueError(f"invalid search-and-replace: {err}") from err

            return lambda s: compiled.sub(repl, s, count=count)

        p = ArgumentParser(description="Diff two Cachegrind output files.")

        p.add_argument("--version", action="version", version="%(prog)s-3.21.0.GIT")

        p.add_argument(
            "--mod-filename",
            type=search_and_replace,
            metavar="REGEX",
            default=search_and_replace(None),
            help="a search-and-replace regex applied to filenames, e.g. "
            "`s/prog[0-9]/progN/`",
        )
        p.add_argument(
            "--mod-funcname",
            type=search_and_replace,
            metavar="REGEX",
            default=search_and_replace(None),
            help="like --mod-filename, but for function names",
        )

        p.add_argument(
            "cgout_filename1",
            nargs=1,
            metavar="cachegrind-out-file1",
            help="file produced by Cachegrind",
        )
        p.add_argument(
            "cgout_filename2",
            nargs=1,
            metavar="cachegrind-out-file2",
            help="file produced by Cachegrind",
        )

        return p.parse_args(namespace=Args())


# Args are stored in a global for easy access. Because this is a module-level
# statement, the command line is parsed as soon as this file is loaded, before
# `main` is called.
args = Args.parse()

# One instance of this class is constructed per cgout file that is read, from
# the text of that file's `events:` line.
class Events:
    """
    The event names of a cgout file, plus factories for matching `Cc` values.
    """

    # The event names.
    events: list[str]

    def __init__(self, text: str) -> None:
        names = text.split()
        self.events = names
        self.num_events = len(names)

    def mk_cc(self, text: str) -> Cc:
        """
        Build a `Cc` from whitespace-separated integer counts.

        Missing trailing counts are taken to be zero. Raises a `ValueError`
        exception on syntax error, i.e. a non-integer count or more counts
        than there are events.
        """
        # `map` is slightly faster than a list comprehension.
        counts = list(map(int, text.split()))

        shortfall = self.num_events - len(counts)
        if shortfall < 0:
            raise ValueError
        if shortfall > 0:
            # Add zeroes at the end for any missing numbers.
            counts += [0] * shortfall

        return Cc(counts)

    def mk_empty_cc(self) -> Cc:
        """Build a `Cc` with one zero count per event."""
        # This is much faster than a list comprehension.
        return Cc([0] * self.num_events)


class Cc:
    """
    A plain holder for a list of counts.

    It has no knowledge of events, i.e. of what each count stands for. It
    supports in-place addition and subtraction and equality comparison;
    anything more must be done elsewhere. Instances are built with
    `Events.mk_cc` and `Events.mk_empty_cc`.
    """

    # Always the same length as `Events.events`.
    counts: list[int]

    def __init__(self, counts: list[int]) -> None:
        # The list is stored as given, not copied.
        self.counts = counts

    def __repr__(self) -> str:
        return repr(self.counts)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Cc):
            return self.counts == other.counts
        return NotImplemented

    def __iadd__(self, other: Cc) -> Cc:
        # Element-wise, in place: `self.counts` is mutated, not replaced.
        mine = self.counts
        for idx, delta in enumerate(other.counts):
            mine[idx] += delta
        return self

    def __isub__(self, other: Cc) -> Cc:
        # Element-wise, in place: `self.counts` is mutated, not replaced.
        mine = self.counts
        for idx, delta in enumerate(other.counts):
            mine[idx] -= delta
        return self


# A paired filename and function name, in that order.
Flfn = NewType("Flfn", tuple[str, str])

# Per-function CCs. This is a `defaultdict`, so looking up an `Flfn` that is
# not yet present inserts a new entry for it.
DictFlfnCc = DefaultDict[Flfn, Cc]


def die(msg: str) -> NoReturn:
    """Print `msg` to stderr as a `cg_diff` error, then exit with status 1."""
    print(f"cg_diff: error: {msg}", file=sys.stderr)
    raise SystemExit(1)


def read_cgout_file(cgout_filename: str) -> tuple[str, Events, DictFlfnCc, Cc]:
    """
    Parse one Cachegrind output file, accumulating counts per function.

    Returns `(cmd, events, dict_flfn_cc, summary_cc)`:
    - `cmd`: the text of the `cmd:` line;
    - `events`: the `Events` built from the `events:` line;
    - `dict_flfn_cc`: for each (filename, function name) pair, the sum of the
      CCs on all of its count lines;
    - `summary_cc`: the CC from the `summary:` line.

    Filenames and function names are passed through `args.mod_filename` and
    `args.mod_funcname` as they are read.

    Every failure (unopenable file, syntax error, missing `summary:` line, or
    a `summary:` line that disagrees with the computed total) is reported via
    `die`, which exits the process.
    """
    # The file format is described in Cachegrind's manual.
    try:
        cgout_file = open(cgout_filename, "r", encoding="utf-8")
    except OSError as err:
        die(f"{err}")

    with cgout_file:
        cgout_line_num = 0

        def parse_die(msg: str) -> NoReturn:
            # Like `die`, but prefixed with the current file position.
            die(f"{cgout_file.name}:{cgout_line_num}: {msg}")

        def readline() -> str:
            # Reads one line, keeping `cgout_line_num` up to date for
            # `parse_die`. Returns "" at EOF.
            nonlocal cgout_line_num
            cgout_line_num += 1
            return cgout_file.readline()

        # Skip the "desc:" lines, which are unused. The loop ends with `line`
        # holding the first non-"desc:" line, or "" at EOF.
        while line := readline():
            if not re.match(r"desc:\s+(.*)", line):
                break

        # Read "cmd:" line. (`line` is already set from the "desc:" loop.)
        if m := re.match(r"cmd:\s+(.*)", line):
            cmd = m.group(1)
        else:
            parse_die("missing a `cmd:` line")

        # Read "events:" line.
        line = readline()
        if m := re.match(r"events:\s+(.*)", line):
            events = Events(m.group(1))
        else:
            parse_die("missing an `events:` line")

        curr_fl = ""
        curr_flfn = Flfn(("", ""))

        # Different places where we accumulate CC data.
        dict_flfn_cc: DictFlfnCc = defaultdict(events.mk_empty_cc)
        summary_cc: Cc | None = None

        # Compile the one hot regex.
        count_pat = re.compile(r"(\d+)\s+(.*)")

        # Line matching is done in order of pattern frequency, for speed.
        while True:
            line = readline()

            if m := count_pat.match(line):
                # The line_num isn't used.
                try:
                    cc = events.mk_cc(m.group(2))
                except ValueError:
                    parse_die("malformed or too many event counts")

                # Record this CC at the function level.
                flfn_cc = dict_flfn_cc[curr_flfn]
                flfn_cc += cc

            elif line.startswith("fn="):
                # Strip the newline explicitly rather than with `[:-1]`, so
                # that a final line lacking a newline doesn't lose a real
                # character.
                fn = line[3:].rstrip("\n")
                curr_flfn = Flfn((curr_fl, args.mod_funcname(fn)))

            elif line.startswith("fl="):
                # A longstanding bug: the use of `--mod-filename` makes it
                # likely that some files won't be found when annotating. This
                # doesn't matter much, because we use line number 0 for all
                # diffs anyway. It just means we get "This file was unreadable"
                # for modified filenames rather than a single "<unknown (line
                # 0)>" CC.
                curr_fl = args.mod_filename(line[3:].rstrip("\n"))
                # A `fn=` line should follow, overwriting the "???".
                curr_flfn = Flfn((curr_fl, "???"))

            elif m := re.match(r"summary:\s+(.*)", line):
                try:
                    summary_cc = events.mk_cc(m.group(1))
                except ValueError:
                    # `mk_cc` raises for non-integer counts as well as for too
                    # many of them, so say so, as for count lines above.
                    parse_die("malformed or too many event counts")

            elif line == "":
                break  # EOF

            elif line == "\n" or line.startswith("#"):
                # Skip empty lines and comment lines.
                pass

            else:
                malformed = line.rstrip("\n")
                parse_die(f"malformed line: {malformed}")

    # Check if summary line was present. `parse_die` only needs the file's
    # name and the final line number, so it is fine to call it after the file
    # has been closed.
    if summary_cc is None:
        parse_die("missing `summary:` line, aborting")

    # Check summary is correct.
    total_cc = events.mk_empty_cc()
    for flfn_cc in dict_flfn_cc.values():
        total_cc += flfn_cc
    if summary_cc != total_cc:
        msg = (
            "`summary:` line doesn't match computed total\n"
            f"- summary: {summary_cc}\n"
            f"- total:   {total_cc}"
        )
        parse_die(msg)

    return (cmd, events, dict_flfn_cc, summary_cc)


def main() -> None:
    """
    Diff the two cgout files named on the command line and print the result.

    The result is written to stdout as `desc:`, `cmd:` and `events:` lines,
    then one `fl=`/`fn=`/count triple per (filename, function name) pair, then
    a `summary:` line. Each count is file 2's value minus file 1's value.
    Exits via `die` if the two files don't record the same events.
    """
    filename1 = args.cgout_filename1[0]
    filename2 = args.cgout_filename2[0]

    (cmd1, events1, dict_flfn_cc1, summary_cc1) = read_cgout_file(filename1)
    (cmd2, events2, dict_flfn_cc2, summary_cc2) = read_cgout_file(filename2)

    # Compare the event names themselves, not just how many there are. The
    # subtraction below is positional, so two files with the same number of
    # different events would otherwise be diffed column against column.
    if events1.events != events2.events:
        die("events don't match")

    # Subtract file 1's CCs from file 2's CCs, at the Flfn level. The lookup
    # in `dict_flfn_cc2` creates a zero CC for any Flfn that is only present in
    # file 1, so such functions show up with negated counts.
    for flfn, flfn_cc1 in dict_flfn_cc1.items():
        flfn_cc2 = dict_flfn_cc2[flfn]
        flfn_cc2 -= flfn_cc1
    summary_cc2 -= summary_cc1

    print(f"desc: Files compared:   {filename1}; {filename2}")
    print(f"cmd: {cmd1}; {cmd2}")
    print("events:", *events1.events, sep=" ")

    # Sort so the output is deterministic.
    def key(flfn_and_cc: tuple[Flfn, Cc]) -> Flfn:
        return flfn_and_cc[0]

    for flfn, flfn_cc2 in sorted(dict_flfn_cc2.items(), key=key):
        # Use `0` for the line number because we don't try to give line-level
        # CCs, due to the possibility of code changes causing line numbers to
        # move around.
        print(f"fl={flfn[0]}")
        print(f"fn={flfn[1]}")
        print("0", *flfn_cc2.counts, sep=" ")

    print("summary:", *summary_cc2.counts, sep=" ")


# Run the diff when this file is executed as a script.
if __name__ == "__main__":
    main()
