"""
subagent_qa_scorer.py

Reads a captured subagent output file and scores it using a two-tier approach:
  - Tier 1 (Haiku): classify as COMPLETE, PARTIAL, or MISSED
  - Tier 2 (Sonnet): only on PARTIAL/MISSED — produce a specific retry instruction

Writes results to state/subagent_outputs/<session_key>.score.json.
"""

import sys
import json
import os
import re
from datetime import datetime, timezone


# Credentials file consulted by load_api_key() before the environment variable.
AUTH_PROFILES_PATH = os.path.expanduser(
    "~/.openclaw/agents/main/agent/auth-profiles.json"
)

# Workspace root: $CLAWSTIN_WORKSPACE when set, otherwise the default under
# the user's home directory.
_DEFAULT_WORKSPACE = os.path.expanduser("~/.openclaw/workspace")
_WORKSPACE_ROOT = os.environ.get("CLAWSTIN_WORKSPACE", _DEFAULT_WORKSPACE)

# Directory holding both the capture files (<session_key>.json) that main()
# reads and the score files (<session_key>.score.json) that it writes.
CAPTURE_DIR = os.path.join(_WORKSPACE_ROOT, "state", "subagent_outputs")

# Model used for the Tier 1 verdict.
HAIKU_MODEL = "claude-haiku-4-5"
# Model used for the Tier 2 retry instruction.
SONNET_MODEL = "claude-sonnet-4-6"


def _anthropic_key_from_profile(profile_id, profile, key_fields):
    """Return the API key held by one profile entry, or "" if it has none.

    Args:
        profile_id: The entry's key in the ``profiles`` mapping, or None when
            the entry comes from a list.
        profile: The decoded profile entry; anything but a dict is ignored.
        key_fields: Field names to try, in priority order.

    An entry counts as Anthropic when its ``provider`` field equals
    "anthropic" case-insensitively. When the entry has no usable ``provider``
    string, the part of ``profile_id`` before the first ":" is compared
    instead, so entries shaped like ``{"anthropic:default": {"key": "..."}}``
    are recognised too.
    """
    if not isinstance(profile, dict):
        return ""

    provider = profile.get("provider")
    if isinstance(provider, str) and provider:
        is_anthropic = provider.lower() == "anthropic"
    else:
        # provider is missing, null, or not a string: fall back to the id.
        is_anthropic = (
            isinstance(profile_id, str)
            and profile_id.split(":", 1)[0].lower() == "anthropic"
        )
    if not is_anthropic:
        return ""

    for field in key_fields:
        value = profile.get(field)
        if isinstance(value, str) and value:
            return value
    return ""


def load_api_key():
    """Load the Anthropic API key from auth-profiles.json, else the environment.

    Lookup order:
      1. Anthropic entries under ``profiles`` in AUTH_PROFILES_PATH (either a
         dict keyed by profile id or a list of profile dicts).
      2. A top-level ``apiKey`` in the same file.
      3. The ``ANTHROPIC_API_KEY`` environment variable.

    Reading the file is best-effort: a missing, unreadable, or malformed file
    is ignored, and a single malformed profile entry no longer aborts the
    search through the remaining entries.

    Returns:
        The first non-empty key found, or None when no source provides one.
    """
    data = None
    try:
        with open(AUTH_PROFILES_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError, RecursionError):
        # OSError: missing/unreadable file. ValueError: invalid JSON or bad
        # encoding. RecursionError: pathologically nested JSON.
        data = None

    if isinstance(data, dict):
        profiles = data.get("profiles", {})
        if isinstance(profiles, dict):
            # Dict form prefers "key" over "apiKey".
            candidates = (
                (profile_id, profile, ("key", "apiKey"))
                for profile_id, profile in profiles.items()
            )
        elif isinstance(profiles, list):
            # List form prefers "apiKey" over "key".
            candidates = (
                (None, profile, ("apiKey", "key")) for profile in profiles
            )
        else:
            candidates = ()

        for profile_id, profile, key_fields in candidates:
            key = _anthropic_key_from_profile(profile_id, profile, key_fields)
            if key:
                return key

        top_level_key = data.get("apiKey", "")
        if isinstance(top_level_key, str) and top_level_key:
            return top_level_key

    # Env var fallback; an empty value counts as absent.
    return os.environ.get("ANTHROPIC_API_KEY") or None


def call_anthropic(api_key, model, prompt):
    """Send one user message to the Anthropic Messages API and return its text.

    Args:
        api_key: Anthropic API key used to build the client.
        model: Model identifier to query.
        prompt: Text sent as the single user message.

    Returns:
        The ``text`` of the first content block that has one, or "" when the
        response contains no such block. The reply is capped at 1024 tokens.
    """
    # The SDK is imported at call time rather than at module import.
    import anthropic

    request = {
        "model": model,
        "max_tokens": 1024,
        "messages": [{"role": "user", "content": prompt}],
    }
    response = anthropic.Anthropic(api_key=api_key).messages.create(**request)

    # First block exposing a .text attribute wins; default to empty string.
    return next(
        (part.text for part in response.content if hasattr(part, "text")),
        "",
    )


# Matches "VERDICT: <label> <dash> <reason>"; the dash may be an em dash,
# a hyphen, or an en dash, and the label is matched case-insensitively.
_VERDICT_LINE = re.compile(
    r"VERDICT:\s*(COMPLETE|PARTIAL|MISSED)\s*[—\-–]\s*(.+)",
    re.IGNORECASE,
)


def parse_verdict(response_text):
    """
    Extract the Tier 1 verdict from a model response.

    The first occurrence of
    ``VERDICT: <COMPLETE|PARTIAL|MISSED> — <reason>`` anywhere in
    ``response_text`` is used; the reason runs to the end of that line.

    Returns a ``(verdict, reason)`` tuple with the verdict upper-cased and the
    reason stripped of surrounding whitespace.
    Raises ValueError when no such line is present.
    """
    found = _VERDICT_LINE.search(response_text)
    if found is None:
        raise ValueError(f"Could not parse verdict from response: {response_text!r}")

    label, explanation = found.groups()
    return label.upper(), explanation.strip()


def main():
    """Score one captured subagent run and persist the result.

    Usage: ``python3 subagent_qa_scorer.py <session_key>``

    Steps:
      1. Load ``<CAPTURE_DIR>/<session_key>.json`` and read its
         ``original_task`` and ``final_output`` strings.
      2. Tier 1: ask HAIKU_MODEL for a COMPLETE/PARTIAL/MISSED verdict and
         print it as ``VERDICT: <verdict> — <reason>``.
      3. Tier 2 (PARTIAL/MISSED only): ask SONNET_MODEL for a retry
         instruction.
      4. Write ``<CAPTURE_DIR>/<session_key>.score.json`` via a temp file and
         ``os.replace``.

    Every failure prints a line starting with ``ERROR:`` and exits with
    status 1; success exits with status 0.
    """
    if len(sys.argv) < 2:
        print("ERROR: missing session_key argument")
        print(f"Usage: python3 {sys.argv[0]} <session_key>")
        sys.exit(1)

    session_key = sys.argv[1]

    capture_path = os.path.join(CAPTURE_DIR, f"{session_key}.json")
    if not os.path.exists(capture_path):
        print(f"ERROR: capture file not found: {capture_path}")
        sys.exit(1)

    try:
        with open(capture_path, "r", encoding="utf-8") as f:
            capture = json.load(f)
    except Exception as e:
        print(f"ERROR: failed to read capture file: {e}")
        sys.exit(1)

    # Valid JSON that is not an object (list, string, number, null) has no
    # .get(); report it like any other bad capture instead of crashing.
    if not isinstance(capture, dict):
        print("ERROR: capture file is not a JSON object")
        sys.exit(1)

    original_task = capture.get("original_task", "")
    final_output = capture.get("final_output", "")

    if not original_task or not final_output:
        print("ERROR: capture file missing original_task or final_output")
        sys.exit(1)

    if not isinstance(original_task, str) or not isinstance(final_output, str):
        print("ERROR: original_task and final_output must be strings")
        sys.exit(1)

    api_key = load_api_key()
    if not api_key:
        print("ERROR: could not find Anthropic API key")
        sys.exit(1)

    # --- Tier 1: Haiku triage ---
    tier1_prompt = (
        f"Task: {original_task}\n"
        f"Output: {final_output}\n\n"
        "Classify this subagent output as COMPLETE, PARTIAL, or MISSED.\n"
        "- COMPLETE: task was fully accomplished\n"
        "- PARTIAL: task was attempted but something was missed or incomplete\n"
        "- MISSED: output doesn't address the task or failed entirely\n"
        'Return exactly one line: "VERDICT: <COMPLETE|PARTIAL|MISSED> \u2014 <one sentence reason>"'
    )

    try:
        tier1_response = call_anthropic(api_key, HAIKU_MODEL, tier1_prompt)
    except Exception as e:
        print(f"ERROR: Tier 1 API call failed: {e}")
        sys.exit(1)

    try:
        verdict, reason = parse_verdict(tier1_response)
    except ValueError as e:
        print(f"ERROR: {e}")
        sys.exit(1)

    print(f"VERDICT: {verdict} — {reason}")

    result = {
        "session_key": session_key,
        "verdict": verdict,
        "reason": reason,
        "scored_at": datetime.now(timezone.utc).isoformat(),
    }

    # --- Tier 2: Sonnet deep analysis (only if not COMPLETE) ---
    if verdict in ("PARTIAL", "MISSED"):
        tier2_prompt = (
            f"A subagent was given this task:\n{original_task}\n\n"
            f"Its output was:\n{final_output}\n\n"
            f"Tier 1 verdict: {verdict} — {reason}\n\n"
            "Analyze why it fell short and write a retry instruction that would "
            "correct the failure. Be specific."
        )

        try:
            retry_instruction = call_anthropic(api_key, SONNET_MODEL, tier2_prompt)
            result["retry_instruction"] = retry_instruction.strip()
        except Exception as e:
            print(f"ERROR: Tier 2 API call failed: {e}")
            sys.exit(1)

    # Write score file atomically: dump to a sibling temp file, then rename
    # over the destination so readers never see a half-written score.
    score_path = os.path.join(CAPTURE_DIR, f"{session_key}.score.json")
    tmp_path = score_path + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, score_path)
    except Exception as e:
        print(f"ERROR: failed to write score file: {e}")
        # Best-effort cleanup: the temp file may not exist or may not be
        # removable; neither case should mask the error exit below.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        sys.exit(1)

    sys.exit(0)


# Run the scorer only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
