#!/usr/bin/env python3
"""
FER Monitor (Fuel Efficiency Reader)
Runs every 5 minutes via LaunchAgent. Zero LLM involvement.

Modes:
  observe  - first 7 days. Log only. One hard circuit breaker at $2.00/5min totalCost delta.
  enforce  - after calibration. Yellow/red thresholds from fer-calibrate.py.
"""

import json
import os
import subprocess
import sys
import urllib.request
import urllib.error
import glob
from datetime import datetime, timezone, timedelta

# --- Config ---
# All persistent state lives under ~/.openclaw/workspace; paths are per-user.
STATE_FILE          = os.path.expanduser("~/.openclaw/workspace/memory/fer-state.json")
LOG_FILE            = os.path.expanduser("~/.openclaw/workspace/memory/fer-log.jsonl")
WIDGET_FILE         = os.path.expanduser("~/.openclaw/workspace/memory/vital-widget.json")
WIDGET_FILE_ROOT    = os.path.expanduser("~/.openclaw/workspace/vital-widget.json")
BALANCE_ANCHOR_FILE = os.path.expanduser("~/.openclaw/workspace/memory/vital-balance-anchor.json")
# Signal recipient for alert messages (E.164 phone number).
ALERT_TARGET        = "+15406208059"
# Length of the initial log-only calibration window, in days.
OBSERVE_DAYS        = 7

# Observe-mode hard circuit breaker: $2.00 output cost in one 5-min window.
# CacheWrite is excluded (front-loaded at session start, not a runaway signal).
CIRCUIT_BREAKER_OUTPUT = 2.00

# Daily budget for the iPhone gas-gauge widget (edit as needed).
DAILY_BUDGET = 8.00

# Lookback window for spend-rate calculation (number of 5-min intervals = last 30 min).
RATE_WINDOW = 6


# --- Helpers ---

def get_usage(days=1):
    """
    Query `openclaw gateway usage-cost --json` for the last *days* days.

    Returns the parsed JSON payload (dict) on success, or None on a
    non-zero exit, timeout, missing binary, or unparseable output.
    """
    cmd = ["openclaw", "gateway", "usage-cost", "--json", "--days", str(days)]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
        return json.loads(proc.stdout) if proc.returncode == 0 else None
    except Exception:
        return None


def fetch_anthropic_usage(api_key, date_str):
    """
    Fetch total cost for date_str (YYYY-MM-DD) from the Anthropic Usage API.
    Requires an Admin API key (sk-ant-admin...).  Returns float cost or None on failure.
    The org usage API returns token counts per model; we compute cost from known pricing.
    Falls back gracefully — never raises.
    """
    try:
        # Anthropic's usage report endpoint (requires Admin API key).
        # Single 1-day bucket covering [date_str 00:00 UTC, next day 00:00 UTC).
        next_day = (datetime.strptime(date_str, "%Y-%m-%d") + timedelta(days=1)).strftime("%Y-%m-%d")
        url = (
            f"https://api.anthropic.com/v1/organizations/usage_report/messages"
            f"?starting_at={date_str}T00:00:00Z"
            f"&ending_at={next_day}T00:00:00Z"
            f"&bucket_width=1d"
        )
        req = urllib.request.Request(url, headers={
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
            "User-Agent": "fer-monitor/1.0",
        })
        with urllib.request.urlopen(req, timeout=15) as resp:
            data = json.loads(resp.read())

        # Sum costs across all model buckets
        # Response: {"data": [{"model": "...", "usage": {"input_tokens": N, "output_tokens": N, ...}, "cost": {"total": X}}]}
        total_cost = 0.0
        for entry in data.get("data", []):
            cost = entry.get("cost", {})
            if isinstance(cost, dict):
                total_cost += cost.get("total", 0.0)
            # Also handle raw token counts if no cost field
            usage = entry.get("usage", {})
            if not cost and usage:
                # Approximate with claude-3-5-sonnet pricing as fallback
                # ($/MTok: input 3.00, output 15.00, cache read 0.30, cache write 3.75)
                inp   = usage.get("input_tokens", 0) / 1_000_000 * 3.0
                out   = usage.get("output_tokens", 0) / 1_000_000 * 15.0
                cr    = usage.get("cache_read_input_tokens", 0) / 1_000_000 * 0.30
                cw    = usage.get("cache_creation_input_tokens", 0) / 1_000_000 * 3.75
                total_cost += inp + out + cr + cw
        return round(total_cost, 6)

    except urllib.error.HTTPError as e:
        # sys is already imported at module level — no need for a local import.
        print(f"[FER] Anthropic Usage API HTTP error {e.code}: {e.reason} "
              f"(admin key required — got regular API key?)", file=sys.stderr)
        return None
    except Exception as e:
        print(f"[FER] Anthropic Usage API unavailable: {e}", file=sys.stderr)
        return None


def fetch_session_cost_today(date_str=None):
    """
    Aggregate cost.total across every OpenClaw session JSONL file (active,
    deleted, and archived) whose mtime falls on date_str (YYYY-MM-DD,
    defaulting to today).  This captures sub-agent sessions that
    `openclaw gateway usage-cost` misses.
    Returns the summed float cost, or None when nothing matched / on error.
    """
    if date_str is None:
        date_str = datetime.now().strftime("%Y-%m-%d")
    try:
        base = os.path.expanduser("~/.openclaw/agents/main/sessions/")
        grand_total = 0.0
        matched = 0
        for pat in ("*.jsonl", "*.jsonl.deleted.*", "*archived*.jsonl"):
            for path in glob.glob(os.path.join(base, pat)):
                try:
                    # Filter by the file's local modification date.
                    stamp = datetime.fromtimestamp(os.path.getmtime(path)).strftime("%Y-%m-%d")
                    if stamp != date_str:
                        continue
                    matched += 1
                    with open(path) as fh:
                        for raw in fh:
                            raw = raw.strip()
                            if not raw:
                                continue
                            try:
                                rec = json.loads(raw)
                                cost = rec.get("message", {}).get("usage", {}).get("cost", {})
                                if cost and isinstance(cost, dict):
                                    grand_total += cost.get("total", 0.0)
                            except Exception:
                                pass  # tolerate malformed lines
                except Exception:
                    pass  # tolerate unreadable / vanished files
        # No file from that date means "no data", not "zero spend".
        return round(grand_total, 6) if matched else None
    except Exception:
        return None


def load_balance_anchor():
    """Read vital-balance-anchor.json; None when missing or unparseable."""
    try:
        with open(BALANCE_ANCHOR_FILE) as fh:
            return json.loads(fh.read())
    except Exception:
        return None


def compute_remaining_balance(anchor, usage_30d, today_total_cost, today_date):
    """
    Compute remaining Anthropic balance given the anchor snapshot.

    anchor: dict with keys anchor_balance, anchor_date, anchor_spend_at_set
            (or falsy when no anchor exists)
    usage_30d: result of get_usage(30) — dict with a "daily" list — or None
    today_total_cost: best-known total spend for today_date
    today_date: local date string, YYYY-MM-DD

    Returns (remaining_balance, spend_since_anchor, balance_gauge_pct);
    all three are None when no anchor is available.
    """
    if not anchor:
        return None, None, None

    anchor_balance  = anchor.get("anchor_balance", 0.0)
    anchor_date     = anchor.get("anchor_date", today_date)
    anchor_spend_at = anchor.get("anchor_spend_at_set", 0.0)

    daily = usage_30d.get("daily", []) if usage_30d else []

    if anchor_date == today_date:
        # Same day: simple delta past the snapshot (clamped at 0).
        spend_since_anchor = max(0.0, today_total_cost - anchor_spend_at)
    else:
        # Multi-day: sum spend from anchor_date onward.
        spend_since_anchor = 0.0
        for d in daily:
            ddate = d.get("date", "")
            if ddate < anchor_date:
                continue
            # .get with default: a malformed daily entry must not crash the monitor
            # (everything else in this function is already defensive).
            day_cost = d.get("totalCost", 0.0)
            if ddate == anchor_date:
                # Only count spend after anchor was set
                spend_since_anchor += max(0.0, day_cost - anchor_spend_at)
            else:
                spend_since_anchor += day_cost

    remaining_balance = round(anchor_balance - spend_since_anchor, 4)
    # Gauge clamps at 0: an overdrawn balance reads as an empty tank.
    balance_gauge_pct = round(max(0.0, remaining_balance / anchor_balance * 100), 1) if anchor_balance > 0 else 0.0
    return remaining_balance, round(spend_since_anchor, 4), balance_gauge_pct


def load_state():
    """Load fer-state.json, or return the fresh default observe-mode state."""
    if not os.path.exists(STATE_FILE):
        # First run (or wiped state): start in observe mode with the
        # hard circuit breaker as the only threshold.
        return {
            "mode": "observe",
            "observe_start": None,
            "last_checkpoint": {},
            "thresholds": {
                "circuit_breaker_output": CIRCUIT_BREAKER_OUTPUT
            },
            "alert_count": 0
        }
    with open(STATE_FILE) as fh:
        return json.load(fh)


def save_state(state):
    """Persist the FER state dict to STATE_FILE as pretty-printed JSON."""
    with open(STATE_FILE, "w") as fh:
        fh.write(json.dumps(state, indent=2))


def send_alert(msg):
    """Send *msg* via Signal; on any failure, append it to fer-alerts.txt."""
    cmd = [
        "openclaw", "message", "send",
        "--channel", "signal",
        "--target", ALERT_TARGET,
        "--message", msg,
    ]
    try:
        subprocess.run(cmd, timeout=15)
    except Exception:
        # Can't alert via Signal — record the alert locally as a fallback.
        fallback = os.path.expanduser("~/.openclaw/workspace/memory/fer-alerts.txt")
        with open(fallback, "a") as fh:
            fh.write(f"{datetime.now().isoformat()} | SIGNAL FAILED | {msg}\n")


def append_log(entry):
    """Append one JSON object as a single line to the FER JSONL log."""
    with open(LOG_FILE, "a") as fh:
        print(json.dumps(entry), file=fh)


def compute_spend_rate(today_date):
    """
    Estimate the current burn rate from the FER log.

    Reads today's log entries (skipping gateway_unreachable events), takes
    the last RATE_WINDOW of them, and derives (tokens_per_hour,
    cost_per_hour) from the first-vs-last cumulative counters in that
    window.  Returns (0, 0.0) when there is not enough data or on error.
    """
    if not os.path.exists(LOG_FILE):
        return 0, 0.0
    try:
        entries = []
        with open(LOG_FILE) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    e = json.loads(line)
                    if e.get("date") == today_date and e.get("event") != "gateway_unreachable":
                        entries.append(e)
                except Exception:
                    pass  # tolerate malformed log lines
        # Last RATE_WINDOW entries (slicing a shorter list returns it whole,
        # so no length guard is needed before the slice).
        window = entries[-RATE_WINDOW:]
        if len(window) < 2:
            return 0, 0.0
        # Rate = delta between first and last cumulative counters in the window.
        # (A previous version also accumulated per-entry deltas here, but that
        # result was never used — removed as dead code.)
        first = window[0]
        last  = window[-1]
        try:
            t_first = datetime.fromisoformat(first["timestamp"])
            t_last  = datetime.fromisoformat(last["timestamp"])
            # Floor of 1/720 h (5 s) prevents division blow-ups on clustered timestamps.
            elapsed_hours = max((t_last - t_first).total_seconds() / 3600, 1/720)
        except Exception:
            # Missing/garbled timestamps: assume the nominal 5-minute cadence.
            elapsed_hours = len(window) * 5 / 60
        token_delta = max(last.get("today_tokens", 0) - first.get("today_tokens", 0), 0)
        cost_delta  = max(last.get("today_total_cost", 0.0) - first.get("today_total_cost", 0.0), 0.0)
        tokens_per_hour = token_delta / elapsed_hours
        cost_per_hour   = cost_delta  / elapsed_hours
        return int(tokens_per_hour), round(cost_per_hour, 4)
    except Exception:
        return 0, 0.0


def fetch_metals_prices():
    """Fetch gold and silver spot prices from gold-api.com. Returns (gold, silver) or (None, None)."""
    try:
        quotes = {}
        for symbol in ("XAU", "XAG"):
            req = urllib.request.Request(
                f"https://api.gold-api.com/price/{symbol}",
                headers={"User-Agent": "vital-widget/3.0"},
            )
            with urllib.request.urlopen(req, timeout=8) as resp:
                payload = json.loads(resp.read())
            quotes[symbol] = round(float(payload["price"]), 2)
        return quotes["XAU"], quotes["XAG"]
    except Exception:
        # Either request failing yields (None, None) — partial data is not returned.
        return None, None


def fetch_xmr_price():
    """Fetch Monero (XMR) price in USD from CoinGecko. Returns float or None."""
    url = "https://api.coingecko.com/api/v3/simple/price?ids=monero&vs_currencies=usd"
    try:
        req = urllib.request.Request(url, headers={"User-Agent": "vital-widget/3.0"})
        with urllib.request.urlopen(req, timeout=8) as resp:
            payload = json.loads(resp.read())
        return round(float(payload["monero"]["usd"]), 2)
    except Exception:
        return None


def compute_last_hour_spend():
    """Sum delta_total_cost over log entries timestamped within the last 60 minutes."""
    if not os.path.exists(LOG_FILE):
        return None
    try:
        one_hour_ago = datetime.now(timezone.utc) - timedelta(hours=1)
        spent = 0.0
        with open(LOG_FILE) as fh:
            for raw in fh:
                raw = raw.strip()
                if not raw:
                    continue
                try:
                    rec = json.loads(raw)
                    when = datetime.fromisoformat(rec.get("timestamp", ""))
                    # Naive timestamps are treated as UTC so comparison is valid.
                    if when.tzinfo is None:
                        when = when.replace(tzinfo=timezone.utc)
                    if when >= one_hour_ago:
                        spent += rec.get("delta_total_cost", 0.0)
                except Exception:
                    pass  # skip malformed lines
        return round(spent, 6)
    except Exception:
        return None


def compute_avg_24h_hourly():
    """Average hourly cost over the trailing 24 h of log entries; None if under 2 samples."""
    if not os.path.exists(LOG_FILE):
        return None
    try:
        floor = datetime.now(timezone.utc) - timedelta(hours=24)
        samples = []  # (aware timestamp, delta_total_cost) in file order
        with open(LOG_FILE) as fh:
            for raw in fh:
                raw = raw.strip()
                if not raw:
                    continue
                try:
                    rec = json.loads(raw)
                    when = datetime.fromisoformat(rec.get("timestamp", ""))
                    # Naive timestamps are assumed UTC.
                    if when.tzinfo is None:
                        when = when.replace(tzinfo=timezone.utc)
                    if when >= floor:
                        samples.append((when, rec.get("delta_total_cost", 0.0)))
                except Exception:
                    pass  # skip malformed lines
        if len(samples) < 2:
            return None
        spend = sum(cost for _, cost in samples)
        # 15-minute floor keeps a near-zero window from producing a huge rate.
        hours = max((samples[-1][0] - samples[0][0]).total_seconds() / 3600, 0.25)
        return round(spend / hours, 6)
    except Exception:
        return None


def build_ticker_string():
    """Assemble the widget ticker from TODO.md plus SCHEDULE.md entries within 3 days."""
    ws = os.path.expanduser("~/.openclaw/workspace")
    pieces = []

    # TODO bullets: lines of the form "- item"
    try:
        with open(os.path.join(ws, "TODO.md")) as fh:
            for raw in fh:
                text = raw.strip()
                if text.startswith("- ") and text[2:].strip():
                    pieces.append("TODO: " + text[2:].strip())
    except Exception:
        pass  # missing/unreadable TODO.md just yields no TODO pieces

    # SCHEDULE lines: "YYYY-MM-DD HH:MM -- Description", next 3 days inclusive
    try:
        today = datetime.now().date()
        horizon = today + timedelta(days=3)
        with open(os.path.join(ws, "SCHEDULE.md")) as fh:
            for raw in fh:
                text = raw.strip()
                if not text or text.startswith("#"):
                    continue
                try:
                    halves = text.split(" -- ", 1)
                    when = datetime.strptime(halves[0].split()[0], "%Y-%m-%d").date()
                    if today <= when <= horizon:
                        # Fall back to the whole line when there is no " -- " separator.
                        pieces.append("SCHEDULE: " + (halves[1].strip() if len(halves) > 1 else text))
                except Exception:
                    pass  # skip lines that don't parse as schedule entries
    except Exception:
        pass

    return " | ".join(pieces) if pieces else ""


def classify_day_type(cost_per_hour, baseline_cost_per_hour):
    """
    Bucket the current spend rate against the calibrated baseline.

    Returns (day_type, ratio): day_type is 'light' / 'normal' / 'heavy',
    or 'unknown' (with ratio None) when no positive baseline exists.
    ratio is current/baseline, rounded to 2 decimals (1.0 = exactly baseline).
    """
    if baseline_cost_per_hour is None or baseline_cost_per_hour <= 0:
        return "unknown", None
    # Classify on the raw ratio; round only the returned value so boundary
    # cases (e.g. 0.496) land in the same bucket as before rounding.
    rate_ratio = cost_per_hour / baseline_cost_per_hour
    label = "light" if rate_ratio < 0.5 else ("heavy" if rate_ratio > 1.5 else "normal")
    return label, round(rate_ratio, 2)


def write_widget_json(state, today_total_cost, today_output_cost, today_tokens,
                      severity, now_iso,
                      remaining_balance=None, spend_since_anchor=None,
                      balance_gauge_pct=None, anchor_date=None):
    """Write vital-widget.json (memory/ copy and workspace root copy) for the Scriptable iPhone widget."""
    local_date = datetime.now().strftime("%Y-%m-%d")
    tok_rate, cost_rate = compute_spend_rate(local_date)

    # Baseline cost/hour exists only after calibration; None during observe.
    thresholds = state.get("thresholds", {})
    day_type, ratio = classify_day_type(cost_rate, thresholds.get("baseline_cost_per_hour", None))

    pct_spent = round((today_total_cost / DAILY_BUDGET) * 100, 1) if DAILY_BUDGET > 0 else 0.0
    pct_left  = max(0.0, round(100 - pct_spent, 1))

    # --- v3.0 fields: market prices and rolling spend stats ---
    gold, silver = fetch_metals_prices()
    xmr = fetch_xmr_price()
    # Key order is preserved deliberately — it shapes the emitted JSON.
    widget = {
        "schema":               "3.0",
        "last_updated":         now_iso,
        "mode":                 state.get("mode", "observe"),
        "days_observed":        state.get("days_observed_snapshot", 0),
        "daily_budget":         DAILY_BUDGET,
        "today_total_cost":     round(today_total_cost, 4),
        "today_output_cost":    round(today_output_cost, 4),
        "today_tokens":         today_tokens,
        "tokens_per_hour":      tok_rate,
        "cost_per_hour":        cost_rate,
        "last_hour_spend":      compute_last_hour_spend(),
        "avg_24h_hourly":       compute_avg_24h_hourly(),
        "day_type":             day_type,
        "day_type_ratio":       ratio,
        "gauge_pct":            pct_left,
        "spent_pct":            pct_spent,
        # --- Anthropic balance fields ---
        "remaining_balance":    remaining_balance,
        "spend_since_anchor":   spend_since_anchor,
        "balance_gauge_pct":    balance_gauge_pct,
        "anchor_date":          anchor_date,
        # --- Market prices ---
        "gold_price":           gold,
        "silver_price":         silver,
        "xmr_usd":              xmr,
        # --------------------------------
        "circuit_breaker_output": thresholds.get("circuit_breaker_output", CIRCUIT_BREAKER_OUTPUT),
        "alert_count":          state.get("alert_count", 0),
        "severity":             severity
    }
    for path in (WIDGET_FILE, WIDGET_FILE_ROOT):
        with open(path, "w") as fh:
            json.dump(widget, fh, indent=2)


# --- Main ---

def main():
    """
    One monitoring pass (invoked every 5 minutes by the LaunchAgent).

    Pipeline: pick the best available cost source (Anthropic Usage API >
    session files > gateway usage-cost), compute deltas against the last
    checkpoint, fire threshold alerts, append a log entry, refresh the
    widget JSON, advance the checkpoint, and auto-switch observe -> enforce
    once the calibration window has elapsed.
    """
    now       = datetime.now(timezone.utc)   # for timestamps in log
    today_dt  = datetime.now()               # local time — usage-cost uses local dates
    today     = today_dt.strftime("%Y-%m-%d")

    # --- Primary cost source: OpenClaw session files (captures all sub-agent sessions) ---
    # Admin API key (sk-ant-admin...) would unlock Anthropic's Usage API for even higher accuracy.
    # For now, session-file aggregation is far more accurate than gateway usage-cost alone.
    ANTHROPIC_API_KEY = None
    try:
        auth_file = os.path.expanduser("~/.openclaw/agents/main/agent/auth-profiles.json")
        with open(auth_file) as _f:
            _auth = json.load(_f)
        ANTHROPIC_API_KEY = _auth.get("profiles", {}).get("anthropic:default", {}).get("key")
    except Exception:
        pass  # no auth file / key — Usage API path is simply skipped

    # Try Anthropic Usage API first (requires admin key, will fail gracefully with regular key)
    today_total_from_api = None
    if ANTHROPIC_API_KEY:
        today_total_from_api = fetch_anthropic_usage(ANTHROPIC_API_KEY, today)
        if today_total_from_api is not None:
            import sys as _sys
            print(f"[FER] Anthropic Usage API: ${today_total_from_api:.4f}", file=_sys.stderr)

    # Try session file aggregation (always available, much more accurate than gateway)
    today_total_from_sessions = fetch_session_cost_today(today)

    # Fall back to gateway usage-cost as last resort
    usage = get_usage(days=1)
    if not usage and today_total_from_sessions is None and today_total_from_api is None:
        # All sources failed — gateway may be down.  Log the outage and bail;
        # gateway_unreachable entries are excluded from rate calculations.
        append_log({
            "timestamp": now.isoformat(),
            "event": "gateway_unreachable",
            "date": today
        })
        return

    # Find today's gateway entry for outputCost and totalTokens (structure still useful)
    daily = usage.get("daily", []) if usage else []
    today_data = next((d for d in daily if d["date"] == today), None)

    # Pick best total cost source: API > sessions > gateway
    if today_total_from_api is not None:
        today_total = today_total_from_api
        _cost_source = "anthropic_api"
    elif today_total_from_sessions is not None:
        today_total = today_total_from_sessions
        _cost_source = "session_files"
        import sys as _sys
        print(f"[FER] Using session-file cost: ${today_total:.4f} "
              f"(gateway reported: ${today_data['totalCost'] if today_data else 'n/a'})", file=_sys.stderr)
    elif today_data:
        today_total = today_data["totalCost"]
        _cost_source = "gateway"
        import sys as _sys
        print(f"[FER] WARNING: Falling back to gateway usage-cost (may undercount sub-agents)", file=_sys.stderr)
    else:
        # No data at all — skip this run (can happen just after midnight)
        return

    # outputCost and totalTokens: use gateway data if available, else estimate
    if today_data:
        today_output = today_data["outputCost"]
        today_tokens = today_data["totalTokens"]
    else:
        today_output = today_total * 0.5  # rough estimate if gateway is down
        today_tokens = 0

    state = load_state()

    # Set observe_start if not set (first ever run starts the calibration clock)
    if not state.get("observe_start"):
        state["observe_start"] = today

    # Determine days observed
    try:
        obs_start = datetime.strptime(state["observe_start"], "%Y-%m-%d").replace(tzinfo=timezone.utc)
        days_observed = (now - obs_start).days
    except Exception:
        days_observed = 0

    # Load last checkpoint.  A checkpoint from a previous day is ignored
    # (treated as 0) so deltas reset at local midnight.
    cp = state.get("last_checkpoint", {})
    last_date   = cp.get("date", "")
    last_total  = cp.get("total_cost", 0.0) if last_date == today else 0.0
    last_output = cp.get("output_cost", 0.0) if last_date == today else 0.0

    # Deltas since last checkpoint (clamped at 0 in case a source switch
    # makes today's total appear to shrink)
    delta_total  = max(0.0, today_total  - last_total)
    delta_output = max(0.0, today_output - last_output)

    mode = state.get("mode", "observe")

    # --- Threshold check ---
    alert_msg = None
    severity  = None

    if mode == "observe":
        # Observe mode: only the single hard circuit breaker can fire.
        cb = state.get("thresholds", {}).get("circuit_breaker_output", CIRCUIT_BREAKER_OUTPUT)
        if delta_output >= cb:
            severity  = "CIRCUIT_BREAKER"
            alert_msg = (
                f"[FER] Runaway spend detected. Output cost: ${delta_output:.3f} in ~5 min "
                f"(circuit breaker: ${cb:.2f}). Today total: ${today_total:.2f}. "
                f"Check active sessions immediately."
            )
    elif mode == "enforce":
        # Enforce mode: calibrated yellow/red thresholds (red checked first).
        thresholds = state.get("thresholds", {})
        yellow = thresholds.get("yellow_output", None)
        red    = thresholds.get("red_output", None)
        if red and delta_output >= red:
            severity  = "RED"
            alert_msg = (
                f"[FER RED] High spend rate: ${delta_output:.3f} output cost in ~5 min "
                f"(threshold: ${red:.3f}). Today total: ${today_total:.2f}. "
                f"Stopping may be needed - check active sessions."
            )
        elif yellow and delta_output >= yellow:
            severity  = "YELLOW"
            alert_msg = (
                f"[FER YELLOW] Elevated spend: ${delta_output:.3f} output cost in ~5 min "
                f"(threshold: ${yellow:.3f}). Today total: ${today_total:.2f}."
            )

    if alert_msg:
        send_alert(alert_msg)
        state["alert_count"] = state.get("alert_count", 0) + 1

    # --- Log entry ---
    append_log({
        "timestamp":        now.isoformat(),
        "date":             today,
        "mode":             mode,
        "days_observed":    days_observed,
        "delta_total_cost": round(delta_total,  6),
        "delta_output_cost": round(delta_output, 6),
        "today_total_cost": round(today_total,  4),
        "today_output_cost": round(today_output, 4),
        "today_tokens":     today_tokens,
        "severity":         severity
    })

    # --- Compute Anthropic balance ---
    # Use scraped balance directly if available and recent (within 2 hours)
    anchor = load_balance_anchor()
    scraped_balance = None
    if anchor:
        last_checked = anchor.get("last_checked", "")
        try:
            import datetime as _dt
            lc = _dt.datetime.fromisoformat(last_checked.replace("Z", "+00:00"))
            age_hours = (datetime.now(timezone.utc) - lc).total_seconds() / 3600
            if age_hours <= 2:
                scraped_balance = anchor.get("anchor_balance")
        except Exception:
            pass
    # NOTE(review): scraped_balance is computed but not used below — looks like
    # a partially-wired feature; confirm before removing.

    anchor_date_str = anchor.get("anchor_date") if anchor else None
    # Only pay for the 30-day gateway query when the anchor predates today.
    usage_30d = None
    if anchor and anchor_date_str and anchor_date_str < today:
        usage_30d = get_usage(days=30)
    remaining_balance, spend_since_anchor, balance_gauge_pct = compute_remaining_balance(
        anchor, usage_30d, today_total, today
    )

    # --- Write widget JSON for iPhone ---
    state["days_observed_snapshot"] = days_observed
    write_widget_json(
        state,
        today_total,
        today_output,
        today_tokens,
        severity,
        now.isoformat(),
        remaining_balance=remaining_balance,
        spend_since_anchor=spend_since_anchor,
        balance_gauge_pct=balance_gauge_pct,
        anchor_date=anchor_date_str,
    )

    # --- Update checkpoint (must happen after the delta computation above) ---
    state["last_checkpoint"] = {
        "date":         today,
        "total_cost":   today_total,
        "output_cost":  today_output,
        "total_tokens": today_tokens,
        "timestamp":    now.isoformat()
    }

    # --- Auto-switch to enforce after calibration ---
    if mode == "observe" and days_observed >= OBSERVE_DAYS:
        # Check if calibration has been run
        thresholds = state.get("thresholds", {})
        if thresholds.get("calibrated"):
            state["mode"] = "enforce"
            # NOTE(review): the '?' fallback would raise on the :.4f format spec;
            # presumably "calibrated" implies yellow/red keys exist — verify.
            send_alert(
                "[FER] Observation period complete. Switching to enforce mode. "
                f"Yellow: ${thresholds.get('yellow_output', '?'):.4f}/5min, "
                f"Red: ${thresholds.get('red_output', '?'):.4f}/5min."
            )

    save_state(state)


# Entry point: run a single monitoring pass (scheduled by the LaunchAgent).
if __name__ == "__main__":
    main()
