#!/usr/bin/env python3

import os
import subprocess
import re
import sys
import argparse
import difflib

import tempfile

# Matches one package declaration: <type> "<name>"<rest>.
# Groups: 1 = type, 2 = name (single- or double-quoted), 3 = everything after
# the closing quote (options such as `, args: [...]` / `, id: 123`; may be "").
# `vscode` is included because `brew bundle dump` emits VS Code extension
# lines unless `--no-vscode` is passed; they should be synced like the rest.
PKG_RE = re.compile(r'^\s*(brew|cask|tap|mas|vscode)\s+["\']([^"\']+)["\'](.*)$')

def colorize_diff(lines):
    """Yield unified-diff lines wrapped in ANSI colour codes for a terminal.

    Added lines are green, removed lines red and ``@@`` hunk headers cyan.
    The ``---``/``+++`` file headers and context lines pass through unchanged.
    The reset code is placed *before* any trailing newline so the colour is
    closed on the line it was opened on.

    Args:
        lines: iterable of strings as produced by ``difflib.unified_diff``
            (usually each ending in ``"\\n"``).

    Yields:
        The same lines, coloured where applicable.
    """
    for line in lines:
        if line.startswith('+') and not line.startswith('+++'):
            color = '32'
        elif line.startswith('-') and not line.startswith('---'):
            color = '31'
        elif line.startswith('@@'):
            # Hunk headers look like "@@ -1,3 +1,4 @@".
            color = '36'
        else:
            yield line
            continue
        body = line.rstrip('\n')
        yield f"\033[{color}m{body}\033[0m{line[len(body):]}"

class Entry:
    """Base class for one logical chunk of a Brewfile.

    Subclasses provide an ordering key, used to sort chunks before the file
    is re-rendered, and the text lines that represent the chunk.
    """

    def sort_key(self):
        """Return the sort key for this chunk. Subclasses must override."""
        raise NotImplementedError()

    def to_lines(self):
        """Return this chunk's output lines. Subclasses must override."""
        raise NotImplementedError()

class PackageEntry(Entry):
    """One package declaration such as ``brew "git", args: ["HEAD"]``.

    ``comments`` holds the lines that should be emitted directly above the
    declaration, so they travel with it when entries are re-sorted.
    """

    # Rank of each declaration kind in the rendered file; any other kind
    # sorts after these.
    _TYPE_RANK = {'tap': 0, 'brew': 1, 'cask': 2, 'mas': 3}
    _UNKNOWN_RANK = 4

    def __init__(self, pkg_type, name, options, comments=None):
        self.pkg_type = pkg_type
        self.name = name
        # Text after the quoted name, e.g. ', args: ["HEAD"]'; surrounding
        # whitespace is dropped.
        self.options = options.strip()
        self.comments = comments or []

    def sort_key(self):
        """Order by declaration kind first, then by (case-sensitive) name."""
        rank = self._TYPE_RANK.get(self.pkg_type, self._UNKNOWN_RANK)
        return (rank, self.name)

    def to_lines(self):
        """Return the attached comment lines followed by the declaration."""
        declaration = f'{self.pkg_type} "{self.name}"'
        if self.options:
            # Options that begin with a comma attach directly to the name.
            separator = '' if self.options.startswith(',') else ' '
            declaration = f'{declaration}{separator}{self.options}'
        return [*self.comments, declaration]

class TextEntry(Entry):
    """A run of lines that is emitted verbatim.

    A header entry sorts before every package entry; a trailing entry sorts
    after all package ranks (0-4).
    """

    _HEADER_RANK = -1
    _TRAILER_RANK = 5

    def __init__(self, lines, is_header=True):
        self.lines = lines
        self.is_header = is_header

    def sort_key(self):
        """Place headers first and trailing text last."""
        rank = self._HEADER_RANK if self.is_header else self._TRAILER_RANK
        return rank, ""

    def to_lines(self):
        """Return the stored lines unchanged."""
        return self.lines

def get_repo_root():
    """Return the absolute path of the repository containing this script.

    Asks git for the work-tree top level. If that fails, the script falls
    back to two directories above its own location (it is expected to live
    in ``bin/macos/``). Failure covers both the script not being inside a
    git work tree and ``git`` not being runnable.

    Returns:
        An absolute directory path.
    """
    script_dir = os.path.dirname(os.path.realpath(__file__))
    try:
        out = subprocess.check_output(
            ['git', 'rev-parse', '--show-toplevel'],
            cwd=script_dir,
            # Discard stderr instead of merging it into stdout, so a git
            # warning can never end up inside the returned path.
            stderr=subprocess.DEVNULL,
        )
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: not a git work tree.
        # OSError (e.g. FileNotFoundError): git missing or not executable.
        return os.path.abspath(os.path.join(script_dir, '..', '..'))
    return out.decode().strip()

def get_ignore_list(repo_root):
    """Collect package names listed in the ignore files.

    These files are read if they exist as regular files:

    * ``<repo_root>/.Brewfile.ignore``
    * ``~/.Brewfile.ignore``
    * ``~/.config/homebrew/ignore``

    Each file holds one name per line. ``#`` starts a comment, and blank
    lines are skipped.

    Args:
        repo_root: directory that holds the repository's Brewfile.

    Returns:
        The union of all names found, as a set of strings.
    """
    candidates = (
        os.path.join(repo_root, '.Brewfile.ignore'),
        os.path.expanduser('~/.Brewfile.ignore'),
        os.path.expanduser('~/.config/homebrew/ignore'),
    )
    ignore = set()
    for path in candidates:
        # isfile rather than exists: a directory at one of these paths
        # would otherwise make open() raise IsADirectoryError.
        if not os.path.isfile(path):
            continue
        with open(path, encoding='utf-8') as f:
            for raw in f:
                name = raw.partition('#')[0].strip()
                if name:
                    ignore.add(name)
    return ignore

def get_current_packages(args):
    """Run ``brew bundle dump`` and return its output lines.

    Args:
        args: parsed CLI namespace. ``args.no_mas`` and ``args.no_vscode``
            add the matching ``--no-mas`` / ``--no-vscode`` dump flags.

    Returns:
        The dump, written to stdout via ``--file=-``, split into lines.

    Exits with status 1, after printing to stderr, if ``brew`` cannot be
    started or the dump exits non-zero. brew's own stderr is shown in that
    case.
    """
    env = os.environ.copy()
    # Homebrew switches: skip auto-update, post-install cleanup and env hints.
    env["HOMEBREW_NO_AUTO_UPDATE"] = "1"
    env["HOMEBREW_NO_INSTALL_CLEANUP"] = "1"
    env["HOMEBREW_NO_ENV_HINTS"] = "1"

    cmd = ['brew', 'bundle', 'dump', '--file=-']
    if args.no_mas:
        cmd.append('--no-mas')
    if args.no_vscode:
        cmd.append('--no-vscode')

    try:
        result = subprocess.run(cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except OSError as exc:
        # e.g. FileNotFoundError when `brew` is not on PATH.
        print(f"Error: could not run 'brew' ({exc}). Is homebrew installed?", file=sys.stderr)
        sys.exit(1)

    if result.returncode != 0:
        print("Error: 'brew bundle dump' failed. Is homebrew installed?", file=sys.stderr)
        details = result.stderr.decode(errors='replace').strip()
        if details:
            print(details, file=sys.stderr)
        sys.exit(1)

    return result.stdout.decode().splitlines()

# First word of a line that opens a multi-line Ruby construct. The word
# boundary (rather than a required trailing space) also catches a bare
# `begin` and forms like `if(...)`.
_BLOCK_KEYWORD_RE = re.compile(r'^(?:if|unless|case|def|begin|while|until)\b')
# Iterator blocks such as `%w[a b].each do |name|`.
_DO_BLOCK_RE = re.compile(r'\bdo\s*(?:\|[^|]*\|)?$')


def _opens_block(stripped):
    """Return True if the (stripped) line starts a multi-line Ruby block.

    Comment lines and `...; end` one-liners never count, so one-liners stay
    in place as ordinary text.
    """
    if stripped.startswith('#') or stripped.endswith('; end'):
        return False
    return bool(_BLOCK_KEYWORD_RE.match(stripped) or _DO_BLOCK_RE.search(stripped))


def _find_footer_start(lines):
    """Return the index at which the verbatim footer begins, or -1 if none.

    The footer starts at the first block-opening line. It is extended
    backwards over any comment or blank lines directly above that line, so
    they stay with the block.
    """
    for i, line in enumerate(lines):
        if _opens_block(line.strip()):
            j = i - 1
            while j >= 0 and (lines[j].strip().startswith('#') or not lines[j].strip()):
                j -= 1
            return j + 1
    return -1


def parse_brewfile(content):
    """
    Parses Brewfile content.

    Everything from the first multi-line Ruby block onwards is kept verbatim
    as the footer. The blocks recognised are ``if``, ``unless``, ``case``,
    ``def``, ``begin``, ``while``, ``until`` and ``... do`` lines. Comment and
    blank lines directly above that block are included in the footer.

    Lines before the footer become entries:

    * the lines before the first package form a header ``TextEntry``;
    * each package line becomes a ``PackageEntry`` carrying the non-package
      lines that preceded it;
    * the lines after the last package form a trailing ``TextEntry``.

    Returns:
    - entries: list of Entry objects
    - conditional_pkgs: set of (type, name) for every package declared
      anywhere in the footer. The footer is written back unchanged, so these
      must not be re-added from the dump. That includes unconditional
      packages that merely appear after a block; re-adding them would
      duplicate them.
    - footer: string (everything from the first block onwards, or "")
    """
    lines = content.splitlines()

    footer_start = _find_footer_start(lines)
    if footer_start == -1:
        unconditional_lines = lines
        footer_lines = []
    else:
        unconditional_lines = lines[:footer_start]
        footer_lines = lines[footer_start:]

    conditional_pkgs = set()
    for line in footer_lines:
        match = PKG_RE.match(line)
        if match:
            conditional_pkgs.add((match.group(1), match.group(2)))
    footer = "\n".join(footer_lines)

    entries = []
    pending = []  # non-package lines waiting to be attached to the next package
    first_pkg_seen = False

    for line in unconditional_lines:
        match = PKG_RE.match(line)
        if not match:
            pending.append(line)
            continue
        if not first_pkg_seen:
            # Whatever precedes the very first package is the file header.
            if pending:
                entries.append(TextEntry(pending, is_header=True))
            pending = []
            first_pkg_seen = True
        entries.append(PackageEntry(match.group(1), match.group(2), match.group(3), pending))
        pending = []

    if pending:
        entries.append(TextEntry(pending, is_header=False))

    return entries, conditional_pkgs, footer

def _merge_packages(old_entries, dumped_pkgs, add_only):
    """Decide which package entries the regenerated Brewfile contains.

    Args:
        old_entries: entries parsed from the current Brewfile.
        dumped_pkgs: PackageEntry objects from the dump, already filtered.
        add_only: if true, keep every existing package (duplicates included)
            and append dumped packages not already present. If false, mirror
            the dump: reuse the existing entry for a package that was already
            listed, keeping its comments and options, and drop the rest.

    Returns:
        (packages, added_count, removed_count, merged_count)
    """
    old_pkgs = [e for e in old_entries if isinstance(e, PackageEntry)]
    packages = []
    seen = set()
    added = removed = merged = 0

    if add_only:
        for e in old_pkgs:
            packages.append(e)
            seen.add((e.pkg_type, e.name))
        for d in dumped_pkgs:
            key = (d.pkg_type, d.name)
            if key not in seen:
                packages.append(d)
                seen.add(key)
                added += 1
        return packages, added, removed, merged

    # First occurrence wins when the old Brewfile lists a package twice.
    old_pkg_map = {}
    for e in old_pkgs:
        old_pkg_map.setdefault((e.pkg_type, e.name), e)

    for d in dumped_pkgs:
        key = (d.pkg_type, d.name)
        if key in seen:
            continue
        existing = old_pkg_map.get(key)
        if existing is not None:
            # Hand-written options win; only fill them in when absent.
            if not existing.options:
                existing.options = d.options
            packages.append(existing)
            merged += 1
        else:
            packages.append(d)
            added += 1
        seen.add(key)

    removed = sum(1 for key in old_pkg_map if key not in seen)
    return packages, added, removed, merged


def _render_brewfile(entries, footer):
    """Render already-sorted entries plus the footer into file text.

    A blank line separates consecutive packages of different types. If a
    blank line was just emitted, a leading blank line of the next entry is
    dropped. A non-empty footer is stripped and appended after a blank line.
    The result always ends with a newline.
    """
    output_lines = []
    last_type = None
    for e in entries:
        if isinstance(e, PackageEntry):
            if last_type and e.pkg_type != last_type:
                output_lines.append("")
            last_type = e.pkg_type

        lines = e.to_lines()
        if output_lines and output_lines[-1] == "" and lines and lines[0] == "":
            output_lines.extend(lines[1:])
        else:
            output_lines.extend(lines)

    content = "\n".join(output_lines)
    if footer:
        if output_lines and output_lines[-1].strip():
            content += "\n\n"
        content += footer.strip() + "\n"
    else:
        content += "\n"
    return content


def _write_atomically(path, content):
    """Replace the file at *path* with *content* via a same-directory temp file.

    Symlinks are resolved first, so the link target is rewritten instead of
    the link being replaced by a regular file. ``NamedTemporaryFile`` creates
    files readable only by the owner, so the original permission bits are
    copied onto the temp file before the swap. The temp file is removed if
    anything fails.
    """
    target = os.path.realpath(path)
    mode = os.stat(target).st_mode & 0o7777
    tf = tempfile.NamedTemporaryFile(
        'w', dir=os.path.dirname(target), delete=False, encoding='utf-8')
    try:
        with tf:
            tf.write(content)
            tf.flush()
            os.fsync(tf.fileno())
        os.chmod(tf.name, mode)
        os.replace(tf.name, target)
    except BaseException:
        try:
            os.unlink(tf.name)
        except OSError:
            pass
        raise


def main(args):
    """Regenerate the repository's Brewfile from ``brew bundle dump``.

    1. The current Brewfile is split by ``parse_brewfile`` into sortable
       entries, a set of footer/conditional packages, and a footer.
    2. Dumped packages are dropped if their name is in an ignore file, or if
       ``parse_brewfile`` reported them as conditional.
    3. The rest are merged with the existing entries and sorted by type, then
       name.
    4. The result is rendered back with the footer appended.

    Args:
        args: parsed CLI namespace with ``dry_run``, ``add_only``,
            ``verbose``, ``no_mas`` and ``no_vscode``.

    Exits with status 1 if the Brewfile does not exist as a file. It also
    exits with status 1 if the dump fails.
    """
    if sys.platform != "darwin":
        print(f"Warning: Running on {sys.platform}. Brewfile is primarily for macOS.", file=sys.stderr)

    repo_root = get_repo_root()
    brewfile_path = os.path.join(repo_root, 'Brewfile')

    if not os.path.isfile(brewfile_path):
        print(f"Error: Brewfile not found at {brewfile_path}", file=sys.stderr)
        sys.exit(1)

    with open(brewfile_path, encoding='utf-8') as f:
        old_content = f.read()

    old_entries, conditional_pkgs, footer = parse_brewfile(old_content)
    ignore_list = get_ignore_list(repo_root)

    dumped_pkgs = []
    for line in get_current_packages(args):
        match = PKG_RE.match(line)
        if not match:
            continue
        pkg_type, pkg_name, pkg_options = match.group(1), match.group(2), match.group(3)
        if pkg_name in ignore_list or (pkg_type, pkg_name) in conditional_pkgs:
            continue
        dumped_pkgs.append(PackageEntry(pkg_type, pkg_name, pkg_options))

    headers = [e for e in old_entries if isinstance(e, TextEntry) and e.is_header]
    trailers = [e for e in old_entries if isinstance(e, TextEntry) and not e.is_header]
    packages, added_count, removed_count, merged_count = _merge_packages(
        old_entries, dumped_pkgs, args.add_only)

    # Stable sort: entries with equal keys keep their relative order.
    new_entries = headers + packages + trailers
    new_entries.sort(key=lambda x: x.sort_key())
    new_content = _render_brewfile(new_entries, footer)

    if new_content == old_content:
        print("Brewfile is already up to date.")
        return

    if args.dry_run:
        print("Changes detected (dry run):")
        # difflib leaves a line without "\n" when the input file does not end
        # in one. Terminate every line so the next diff line is not glued on.
        diff = [
            line if line.endswith('\n') else line + '\n'
            for line in difflib.unified_diff(
                old_content.splitlines(keepends=True),
                new_content.splitlines(keepends=True),
                fromfile='Brewfile (original)',
                tofile='Brewfile (new)',
            )
        ]
        if sys.stdout.isatty():
            sys.stdout.writelines(colorize_diff(diff))
        else:
            sys.stdout.writelines(diff)
    else:
        _write_atomically(brewfile_path, new_content)
        print("Brewfile updated.")

    if args.verbose:
        print(f"Summary: {added_count} added, {removed_count} removed, {merged_count} kept/merged.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Update Brewfile while preserving conditionals."
    )
    # Every option is a plain on/off switch, so register them from a table.
    for flags, help_text in (
        (("--dry-run",), "Show changes without applying them."),
        (("--add-only",), "Only add missing entries, do not remove existing ones."),
        (("--verbose", "-v"), "Print summary of changes."),
        (("--no-mas",), "Do not include Mac App Store apps."),
        (("--no-vscode",), "Do not include VSCode extensions."),
    ):
        parser.add_argument(*flags, action="store_true", help=help_text)
    args = parser.parse_args()
    main(args)
