#!/usr/bin/env python3
"""L1 + L2: 从 Claude Code session jsonl 抽取结构化 turn，按规则筛工作流信号点。

L1 抽取：jsonl → 结构化 turn (role, text, tool_calls, ts, turn_gap_seconds)
         去重 system-reminder / 裁 tool_result 内容 / 敏感词过滤
L2 信号筛：跑规则匹配，输出候选信号点（含 ±3 turn 上下文）

输出：JSON 给 run.py 喂 LLM；也可直接 print --human 看人话调试。
"""
from __future__ import annotations

import argparse
import datetime as dt
import glob
import hashlib
import json
import os
import re
import sys
from pathlib import Path
from typing import Any

# ---- 配置 -----------------------------------------------------------------

RESULT_HEAD_CHARS = 200  # keep only the first N chars of each tool_result
TURN_GAP_LONG_SECONDS = 1800  # only gaps > 30 min count as "resumed after a break" (5 min was too noisy)
EDIT_REPEAT_WINDOW = 10  # look-back window (in turns) for the repeated-edit rule
EDIT_REPEAT_THRESHOLD = 3  # edits of one file inside the window that trigger edit_repeat
TOOL_ERROR_STREAK = 2  # consecutive failing tool results that trigger tool_error_streak
ASK_USER_QUESTION_THRESHOLD = 2  # AskUserQuestion turns inside the cluster window that trigger ask_cluster
SIGNAL_CONTEXT_TURNS = 3  # attach ±N turns around each signal as context

# Phrases in a user message that suggest friction with the agent.
FRUSTRATION_PATTERNS = [
    r"为什么", r"不对", r"重做", r"还有别的", r"还有其他",
    r"别这样", r"不要这样", r"停一下", r"等等", r"打住",
    r"你看错", r"你理解错", r"再想想",
]

# Secret pre-filter: a turn matching any of these is dropped entirely
# (it never reaches L2 and is never fed to the LLM).
REDACT_PATTERNS = [
    r"sk-[a-zA-Z0-9]{32,}",  # legacy plain-alphanumeric sk- API key
    # Anthropic (sk-ant-api03-…) and OpenAI project (sk-proj-…) keys contain '-' / '_',
    # which the pattern above cannot match. The lookbehind avoids hits inside words like "task-".
    r"(?<![A-Za-z0-9])sk-(?:ant|proj)-[A-Za-z0-9_-]{20,}",
    r"gh[pousr]_[a-zA-Z0-9]{36,}",  # GitHub PAT / OAuth / user / server / refresh tokens
    r"github_pat_[A-Za-z0-9_]{22,}",  # GitHub fine-grained PAT
    r"AKIA[0-9A-Z]{16}",  # AWS access key id
    r"ffai_[A-Za-z0-9_-]{20,}",  # FFAI personal token
    r"eyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+",  # JWT
    # PEM private key: bare PKCS#8 plus RSA / EC / DSA / OPENSSH / ENCRYPTED variants
    r"-----BEGIN (?:[A-Z]+ )?PRIVATE KEY-----",
]

FRUSTRATION_RE = re.compile("|".join(FRUSTRATION_PATTERNS))
REDACT_RE = re.compile("|".join(REDACT_PATTERNS))


# ---- 共享 helper（供 run.py / run_all.py import） --------------------------


def empty_payload(date_str: str) -> dict:
    """Return a placeholder signals payload with no sessions and no signals.

    Used when the jsonl files are missing, the user opted out, or extraction
    failed. It is the single definition of the payload shape, so run.py and
    run_all.py do not repeat the literal; change the fields here only.
    """
    return dict(opted_out=False, date=date_str, sessions=[], signals=[])


# ---- L1 抽取 --------------------------------------------------------------


def encode_project_path(repo_root: str) -> str:
    """Map a cwd to its Claude Code project-directory name by turning every "/" into "-".

    NOTE(review): only "/" is rewritten here. Confirm whether Claude Code also
    rewrites other characters (e.g. ".") before relying on this for lookups.
    """
    return "-".join(repo_root.split("/"))


def discover_jsonl_files(home: Path, date_str: str) -> list[Path]:
    """Return every project session jsonl that may contain turns dated ``date_str``.

    Files live at ``<home>/.claude/projects/<encoded-cwd>/*.jsonl``. A file whose
    mtime (local date) is earlier than the day before ``date_str`` was last written
    before that day began, so it cannot hold turns from it and is skipped. The
    one-day slack absorbs sessions that cross midnight or timezone skew.

    There is deliberately no upper bound. A session that keeps running, or is
    appended to on a later day, still holds its earlier turns in the same file.
    load_turns() does the exact per-turn date filtering.

    Returns a sorted list; [] when the projects directory does not exist.
    Raises ValueError if ``date_str`` is not ISO ``YYYY-MM-DD`` (only checked
    when the projects directory exists).
    """
    proj_root = home / ".claude" / "projects"
    if not proj_root.is_dir():
        return []
    earliest = dt.date.fromisoformat(date_str) - dt.timedelta(days=1)
    files: list[Path] = []
    for jsonl in proj_root.glob("*/*.jsonl"):
        try:
            mtime = dt.datetime.fromtimestamp(jsonl.stat().st_mtime).date()
        except OSError:
            # File vanished or became unreadable between glob and stat.
            continue
        if mtime >= earliest:
            files.append(jsonl)
    return sorted(files)


def args_signature(tool_name: str, tool_input: Any) -> str:
    """Build a short fingerprint ``"<tool_name>::<key args>"`` for one tool call.

    - File tools (Edit / MultiEdit / Read / Write / NotebookEdit) use the target path.
      NotebookEdit's target is assumed to be ``notebook_path``; ``file_path`` is
      tried first.
    - Bash uses the first 80 chars of ``command``; Grep the first 60 chars of ``pattern``.
    - Any other tool uses an 8-hex-char md5 of its canonical JSON input. The md5
      is a fingerprint, not a security measure.
    - A non-dict input yields an empty key part (``"<tool_name>::"``).
    """
    if not isinstance(tool_input, dict):
        return f"{tool_name}::"
    if tool_name in ("Edit", "MultiEdit", "Read", "Write", "NotebookEdit"):
        key_args = str(tool_input.get("file_path") or tool_input.get("notebook_path") or "")
    elif tool_name == "Bash":
        key_args = str(tool_input.get("command", ""))[:80]
    elif tool_name == "Grep":
        key_args = str(tool_input.get("pattern", ""))[:60]
    else:
        canonical = json.dumps(tool_input, sort_keys=True, default=str)
        key_args = hashlib.md5(canonical.encode()).hexdigest()[:8]
    return f"{tool_name}::{key_args}"


def extract_text_from_content(content: Any) -> str:
    """Return the plain text carried by a message ``content`` value.

    A bare string is returned unchanged. A list of content blocks contributes only
    its ``type == "text"`` blocks whose text is a str (thinking / tool_use /
    tool_result blocks are ignored); these are newline-joined and stripped.
    Anything else yields "".
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        candidates = (
            block.get("text", "")
            for block in content
            if isinstance(block, dict) and block.get("type") == "text"
        )
        pieces = [piece for piece in candidates if isinstance(piece, str)]
        return "\n".join(pieces).strip()
    return ""


def extract_tool_calls(content: Any) -> list[dict]:
    """Collect the ``tool_use`` blocks of an assistant message ``content`` list.

    Each result carries ``id``, ``name`` (default ""), the raw ``input``, and
    ``sig`` from args_signature(). A non-list ``content`` yields [].
    """
    if not isinstance(content, list):
        return []
    calls: list[dict] = []
    for block in content:
        if not isinstance(block, dict) or block.get("type") != "tool_use":
            continue
        name = block.get("name", "")
        payload = block.get("input")
        calls.append({"id": block.get("id"), "name": name, "input": payload, "sig": args_signature(name, payload)})
    return calls


def extract_tool_results(content: Any, head_chars: int | None = None) -> list[dict]:
    """Collect the ``tool_result`` blocks of a user message ``content`` list.

    Args:
        content: message content; anything that is not a list yields [].
        head_chars: how many leading characters of each result to keep.
            Defaults to RESULT_HEAD_CHARS.

    Returns:
        One dict per tool_result with ``tool_use_id``, ``is_error`` (bool),
        ``result_head`` (first ``head_chars`` chars) and ``result_total_len``
        (length of the full result text, not of the head).
    """
    if not isinstance(content, list):
        return []
    limit = RESULT_HEAD_CHARS if head_chars is None else head_chars
    out: list[dict] = []
    for blk in content:
        if not (isinstance(blk, dict) and blk.get("type") == "tool_result"):
            continue
        raw_c = blk.get("content")
        if isinstance(raw_c, list):
            # Content may be a list of {type: "text", text: ...} blocks; only str
            # text counts. Join everything so the total length is exact.
            pieces = [
                sub.get("text")
                for sub in raw_c
                if isinstance(sub, dict) and sub.get("type") == "text" and isinstance(sub.get("text"), str)
            ]
            full = "".join(pieces)
        else:
            full = str(raw_c or "")
        out.append(
            {
                "tool_use_id": blk.get("tool_use_id"),
                "is_error": bool(blk.get("is_error", False)),
                "result_head": full[:limit],
                "result_total_len": len(full),
            }
        )
    return out


def is_redacted(text: str) -> bool:
    """Return True if ``text`` contains anything matched by the REDACT_PATTERNS secret filter."""
    match = REDACT_RE.search(text)
    return match is not None


def _turn_has_secret(text: str, tool_calls: list[dict], tool_results: list[dict]) -> bool:
    """Return True if any part of a turn that can reach the output matches the secret filter.

    Checks the plain text, each tool call's signature and serialized input, and each
    tool result head. Checking ``text`` alone is not enough: tool_use and
    tool_result turns usually have empty text, yet their signatures (e.g. a Bash
    command line) and result heads are emitted in signal context.
    """
    if text and is_redacted(text):
        return True
    for call in tool_calls:
        if is_redacted(call.get("sig") or ""):
            return True
        if is_redacted(json.dumps(call.get("input"), ensure_ascii=False, default=str)):
            return True
    return any(is_redacted(res.get("result_head") or "") for res in tool_results)


def load_turns(jsonl_path: Path, target_date: str) -> list[dict]:
    """Read one session jsonl and return its ``user``/``assistant`` turns dated ``target_date``.

    Turns are returned in file order; the caller sorts by ``ts``. A turn's date is
    its timestamp converted to local time, matching the local-date default of
    ``--date`` and the local mtime used by discover_jsonl_files(). Blank, non-JSON
    and non-object lines are skipped, as are lines without a str timestamp. Turns
    with any secret-pattern match (text, tool input, or tool result head) are
    dropped entirely. An unreadable file yields [].

    Each turn dict has: ts (original-offset isoformat) / role / text / tool_calls /
    tool_results / session_id / git_branch / cwd / is_sidechain.
    """
    turns: list[dict] = []
    try:
        with open(jsonl_path, "r", encoding="utf-8", errors="replace") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(obj, dict):
                    continue
                role = obj.get("type")
                if role not in ("user", "assistant"):
                    continue
                ts_str = obj.get("timestamp")
                if not isinstance(ts_str, str) or not ts_str:
                    continue
                try:
                    ts = dt.datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
                except ValueError:
                    continue
                # Compare in local time, not UTC, so "today" means the user's today.
                if ts.astimezone().date().isoformat() != target_date:
                    continue
                msg = obj.get("message")
                if not isinstance(msg, dict):
                    msg = {}
                content = msg.get("content")
                text = extract_text_from_content(content)
                if role == "assistant":
                    tool_calls = extract_tool_calls(content)
                    tool_results: list[dict] = []
                else:
                    tool_calls = []
                    tool_results = extract_tool_results(content)
                # Drop the whole turn on any secret hit (it never reaches L2 / the LLM).
                if _turn_has_secret(text, tool_calls, tool_results):
                    continue
                turns.append(
                    {
                        "ts": ts.isoformat(),
                        "role": role,
                        "text": text,
                        "tool_calls": tool_calls,
                        "tool_results": tool_results,
                        "session_id": obj.get("sessionId"),
                        "git_branch": obj.get("gitBranch"),
                        "cwd": obj.get("cwd"),
                        "is_sidechain": obj.get("isSidechain", False),
                    }
                )
    except OSError:
        return []
    return turns


def compute_turn_gaps(turns: list[dict]) -> None:
    """Fill ``turn_gap_seconds`` in place with whole seconds since the previous turn.

    The first turn gets ``None``, and so does any turn whose ``ts`` does not
    parse. An unparsable ``ts`` does not advance the reference point. A turn
    whose timestamp cannot be subtracted from the previous one (offset-naive
    vs. offset-aware) also gets ``None`` instead of aborting the whole run.
    """
    prev_ts: dt.datetime | None = None
    for turn in turns:
        try:
            cur = dt.datetime.fromisoformat(turn["ts"])
        except (ValueError, TypeError):
            turn["turn_gap_seconds"] = None
            continue
        gap = None
        if prev_ts is not None:
            try:
                gap = int((cur - prev_ts).total_seconds())
            except TypeError:
                # Mixed naive/aware timestamps: no meaningful gap.
                gap = None
        turn["turn_gap_seconds"] = gap
        prev_ts = cur


# ---- L2 信号筛 ------------------------------------------------------------


def l2_signals(turns: list[dict]) -> list[dict]:
    """规则匹配出候选信号点。每个 signal 含 type / turn_index / detail / context_window。"""
    signals: list[dict] = []

    # 1) Edit 同一 file path 在窗口内 ≥ 阈值
    edit_history: list[tuple[int, str]] = []  # (turn_index, file_path)
    for i, t in enumerate(turns):
        for call in t.get("tool_calls", []):
            if call["name"] in ("Edit", "Write", "NotebookEdit"):
                fp = (call.get("input") or {}).get("file_path", "")
                if not fp:
                    continue
                edit_history.append((i, fp))
                # 检查窗口内该 file_path 的重复次数
                window_lo = i - EDIT_REPEAT_WINDOW
                cnt = sum(1 for (idx, f) in edit_history if idx >= window_lo and f == fp)
                if cnt >= EDIT_REPEAT_THRESHOLD:
                    signals.append(
                        {
                            "type": "edit_repeat",
                            "turn_index": i,
                            "detail": {"file_path": fp, "count_in_window": cnt},
                        }
                    )

    # 2) tool_result is_error 连续 ≥ N
    streak = 0
    for i, t in enumerate(turns):
        had_err = any(r.get("is_error") for r in t.get("tool_results", []))
        if had_err:
            streak += 1
            if streak >= TOOL_ERROR_STREAK:
                signals.append(
                    {
                        "type": "tool_error_streak",
                        "turn_index": i,
                        "detail": {"streak": streak},
                    }
                )
        else:
            streak = 0

    # 3) turn_gap > 5 min
    for i, t in enumerate(turns):
        gap = t.get("turn_gap_seconds")
        if gap and gap > TURN_GAP_LONG_SECONDS:
            signals.append(
                {
                    "type": "long_gap",
                    "turn_index": i,
                    "detail": {"gap_seconds": gap},
                }
            )

    # 4) frustration token 在 user message
    for i, t in enumerate(turns):
        if t["role"] != "user":
            continue
        text = t.get("text", "")
        if text and FRUSTRATION_RE.search(text):
            signals.append(
                {
                    "type": "frustration",
                    "turn_index": i,
                    "detail": {"text_head": text[:120]},
                }
            )

    # 5) AskUserQuestion 集群（≥ N 次在 ≤ 5 turn 窗口）
    asks = [i for i, t in enumerate(turns) if any(c["name"] == "AskUserQuestion" for c in t.get("tool_calls", []))]
    for j in range(len(asks)):
        window = [a for a in asks if 0 <= asks[j] - a <= 5]
        if len(window) >= ASK_USER_QUESTION_THRESHOLD:
            signals.append(
                {
                    "type": "ask_cluster",
                    "turn_index": asks[j],
                    "detail": {"count_in_window": len(window)},
                }
            )
            break  # 报第一次集群即可，避免重复

    # 去重 — 按"信号主体"而非位置：
    #  edit_repeat: 同 file_path 仅保留 count 最高的一次
    #  tool_error_streak: 同一连续段仅保留最大 streak
    #  long_gap: 全 session 留 1 个最长的
    #  frustration: 全 session 至多 3 条最显著的
    #  ask_cluster: 全 session 仅 1 条
    by_type: dict[str, list[dict]] = {}
    for s in signals:
        by_type.setdefault(s["type"], []).append(s)

    dedup: list[dict] = []

    # edit_repeat: 同 file_path 取 max count
    er_best: dict[str, dict] = {}
    for s in by_type.get("edit_repeat", []):
        fp = s["detail"]["file_path"]
        if fp not in er_best or s["detail"]["count_in_window"] > er_best[fp]["detail"]["count_in_window"]:
            er_best[fp] = s
    dedup.extend(er_best.values())

    # tool_error_streak: 同一连续段（turn_index 差 ≤ 1）取 max streak
    tes = sorted(by_type.get("tool_error_streak", []), key=lambda x: x["turn_index"])
    if tes:
        groups: list[list[dict]] = [[tes[0]]]
        for s in tes[1:]:
            if s["turn_index"] - groups[-1][-1]["turn_index"] <= 1:
                groups[-1].append(s)
            else:
                groups.append([s])
        for g in groups:
            best = max(g, key=lambda x: x["detail"]["streak"])
            dedup.append(best)

    # long_gap: 取最长一次
    lg = by_type.get("long_gap", [])
    if lg:
        dedup.append(max(lg, key=lambda x: x["detail"]["gap_seconds"]))

    # frustration: 取前 3
    fr = by_type.get("frustration", [])
    dedup.extend(fr[:3])

    # ask_cluster: 取 count 最高一次（已在生成阶段 break，这里再保险）
    ac = by_type.get("ask_cluster", [])
    if ac:
        dedup.append(max(ac, key=lambda x: x["detail"]["count_in_window"]))

    dedup.sort(key=lambda x: x["turn_index"])
    return dedup


def add_context_to_signals(signals: list[dict], turns: list[dict]) -> list[dict]:
    """Return copies of ``signals``, each with a ``context`` list of nearby turns.

    The window spans up to SIGNAL_CONTEXT_TURNS turns on each side of the
    signal's ``turn_index``, clamped to the session. Each context entry has:
    ``rel`` (offset from the signal turn), ``role``, and ``text`` (cut to 300
    chars, with "…" when cut). Optionally it also has ``tools`` (call signatures)
    and ``results`` (is_error plus the first 80 chars of each result head).
    The input dicts are not mutated.
    """

    def _clip(text: str) -> str:
        if not text:
            return ""
        return text[:300] + "…" if len(text) > 300 else text

    def _summarize(turn: dict, offset: int) -> dict:
        entry = {"rel": offset, "role": turn["role"], "text": _clip(turn["text"])}
        calls = turn.get("tool_calls")
        if calls:
            entry["tools"] = [c["sig"] for c in calls]
        results = turn.get("tool_results")
        if results:
            entry["results"] = [
                {
                    "is_error": r.get("is_error"),
                    "head": r["result_head"][:80] if r.get("result_head") else "",
                }
                for r in results
            ]
        return entry

    enriched: list[dict] = []
    for sig in signals:
        center = sig["turn_index"]
        first = max(0, center - SIGNAL_CONTEXT_TURNS)
        stop = min(len(turns), center + SIGNAL_CONTEXT_TURNS + 1)
        context = [_summarize(turns[j], j - center) for j in range(first, stop)]
        enriched.append({**sig, "context": context})
    return enriched


# ---- 主入口 ---------------------------------------------------------------


def run(date_str: str, home: Path, opt_out_flag: Path | None) -> dict:
    """Build the L1+L2 signals payload for ``date_str``.

    If ``opt_out_flag`` exists, returns the empty payload marked ``opted_out`` and
    scans nothing. Otherwise it loads that day's turns from every candidate jsonl
    and groups them by session id, falling back to the file stem. Each session is
    ordered by timestamp and run through the L2 rules. Signals are tagged with
    their session id; sessions are listed in order of their first turn. The
    payload shape comes from empty_payload(), the single definition of its fields.
    """
    if opt_out_flag and opt_out_flag.exists():
        return {**empty_payload(date_str), "opted_out": True}

    turns_by_session: dict[str, list[dict]] = {}
    for path in discover_jsonl_files(home, date_str):
        for turn in load_turns(path, date_str):
            sid = turn.get("session_id") or path.stem
            turns_by_session.setdefault(sid, []).append(turn)

    sessions: list[dict] = []
    signals: list[dict] = []
    for sid, turns in turns_by_session.items():
        turns.sort(key=lambda x: x["ts"])
        compute_turn_gaps(turns)
        session_signals = add_context_to_signals(l2_signals(turns), turns)
        sessions.append(
            {
                "session_id": sid,
                "turn_count": len(turns),
                "first_ts": turns[0]["ts"],
                "last_ts": turns[-1]["ts"],
                "git_branch": next((t.get("git_branch") for t in turns if t.get("git_branch")), None),
                "cwd": next((t.get("cwd") for t in turns if t.get("cwd")), None),
                "signal_count": len(session_signals),
            }
        )
        signals.extend({**s, "session_id": sid} for s in session_signals)

    payload = empty_payload(date_str)
    payload["sessions"] = sorted(sessions, key=lambda x: x["first_ts"])
    payload["signals"] = signals
    return payload


def print_human(payload: dict) -> None:
    """Print a human-readable debug summary of a payload produced by run().

    Shows a one-line opt-out notice if the payload is opted out. Otherwise it
    prints a header, one line per session (ids shortened to 8 chars), then one
    line per signal.
    """
    if payload.get("opted_out"):
        print("(opted-out via ~/.claude-insight-optout)")
        return
    print(f"=== Date: {payload['date']} ===")
    sessions = payload["sessions"]
    signals = payload["signals"]
    print(f"Sessions: {len(sessions)}, total signals: {len(signals)}")
    for sess in sessions:
        short_id = sess["session_id"][:8]
        branch = sess.get("git_branch") or "?"
        cwd = (sess.get("cwd") or "")[:50]
        print(f"\n[session {short_id}] branch={branch} cwd={cwd} turns={sess['turn_count']} signals={sess['signal_count']}")
    print("\n--- Signals ---")
    for sig in signals:
        print(f"  [{sig['type']}] turn#{sig['turn_index']} session={sig['session_id'][:8]} detail={sig['detail']}")


def main() -> int:
    """Command-line entry point: build the payload for --date and emit it.

    Writes JSON to --out (UTF-8) or to stdout, or prints a human summary with
    --human. Returns 0. A malformed --date is rejected through argparse
    (exit status 2) instead of a traceback.
    """
    p = argparse.ArgumentParser(description="Extract L1+L2 workflow signals from Claude Code jsonl")
    p.add_argument("--date", default=dt.date.today().isoformat(), help="YYYY-MM-DD, default today")
    p.add_argument("--home", default=str(Path.home()), help="user home, default $HOME")
    p.add_argument("--out", help="write JSON to this path (default: stdout)")
    p.add_argument("--human", action="store_true", help="print human-readable summary instead of JSON")
    args = p.parse_args()
    try:
        # Normalize to canonical YYYY-MM-DD: load_turns compares dates as strings.
        args.date = dt.date.fromisoformat(args.date).isoformat()
    except ValueError:
        p.error(f"--date must be YYYY-MM-DD, got {args.date!r}")

    home = Path(args.home)
    # The presence of this file opts the user out of extraction entirely.
    opt_out = home / ".claude-insight-optout"
    payload = run(args.date, home, opt_out)

    if args.human:
        print_human(payload)
        return 0

    if args.out:
        # ensure_ascii=False keeps CJK text as-is, so don't rely on the locale encoding.
        Path(args.out).write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
        print(f"wrote {args.out} ({len(payload['signals'])} signals, {len(payload['sessions'])} sessions)", file=sys.stderr)
    else:
        json.dump(payload, sys.stdout, ensure_ascii=False, indent=2)
    return 0


if __name__ == "__main__":
    sys.exit(main())
