Zion Boggan
repos/Pitch Tracker CV/tools/analyze_session.py
zionboggan.com ↗
99 lines · python
History for this file →
1
"""Analyze a recorded events JSONL file for pitch count, plate_y calibration, and PCI stability.

Usage: analyze_session.py <events.jsonl>

Reads one JSON event per line, then prints ball/PCI tracker hit rates, a
per-pitch summary of ball_track points, a suggested plate_y_frac, and the
spread (mean/std/range) of PCI centroid positions.
"""
2
from __future__ import annotations
3
 
4
import json
5
import sys
6
from collections import defaultdict
7
from pathlib import Path
8
 
9
import numpy as np
10
from rich.console import Console
11
 
12
# Single rich console shared by every report function in this script.
console: Console = Console()
13
 
14
def load_events(path: Path) -> list[dict]:
    """Load one JSON object per line from the JSONL file at *path*.

    Blank lines are ignored. Lines that are not valid JSON, or that decode to
    something other than a JSON object, are skipped; if any were skipped a
    single warning with the count is printed, so data loss is not silent.

    Args:
        path: Path to the events ``.jsonl`` file (read as UTF-8).

    Returns:
        The decoded event dicts, in file order.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    events: list[dict] = []
    skipped = 0
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            # Downstream analysis calls .get() on every event, so only JSON
            # objects are usable; arrays/scalars would crash later.
            if not isinstance(obj, dict):
                skipped += 1
                continue
            events.append(obj)
    if skipped:
        # markup=False: the path is user-supplied and may contain "[...]".
        console.print(
            f"warning: skipped {skipped} malformed line(s) in {path}",
            style="yellow",
            markup=False,
        )
    return events
26
 
27
def analyze(
    events: list[dict],
    *,
    frame_height: float = 1080.0,
    min_dy: float = 100.0,
) -> None:
    """Print a session report for *events* to the module-level console.

    Sections, in order: tracker hit totals, a per-pitch summary of
    ``ball_track`` points, a suggested ``plate_y_frac``, and PCI centroid
    spread. Events are routed by their ``"type"`` value; other types are
    ignored.

    Args:
        events: Decoded event dicts (e.g. from ``load_events``).
            ``ball_track`` events are read for ``"x"``, ``"y"``, ``"ts_ns"``
            and an optional ``"pitch_id"``; ``pitch_pred`` events for an
            optional ``"pitch_id"``; ``pci_track`` events for ``"x"`` and
            ``"y"``.
        frame_height: Frame height in pixels. The suggested ``plate_y_frac``
            is the 90th percentile of qualifying per-pitch max y divided by
            this value. Defaults to 1080 (the previously hard-coded value).
        min_dy: Minimum vertical span, in pixels, for a pitch to count toward
            plate_y calibration. Defaults to 100.

    Raises:
        ValueError: If *frame_height* is not positive.
        KeyError: If a track event that is analyzed lacks one of the fields
            listed above.
    """
    if frame_height <= 0:
        raise ValueError(f"frame_height must be positive, got {frame_height!r}")

    ball_tracks = [e for e in events if e.get("type") == "ball_track"]
    ball_misses = [e for e in events if e.get("type") == "ball_miss"]
    pitch_preds = [e for e in events if e.get("type") == "pitch_pred"]
    pci_tracks = [e for e in events if e.get("type") == "pci_track"]
    pci_misses = [e for e in events if e.get("type") == "pci_miss"]

    total_ball = len(ball_tracks) + len(ball_misses)
    total_pci = len(pci_tracks) + len(pci_misses)
    console.print("[bold]Session totals[/bold]")
    console.print(
        f"  ball: {len(ball_tracks)}/{total_ball} hits "
        f"({_pct(len(ball_tracks), total_ball):.1f}%), {len(pitch_preds)} predictions"
    )
    console.print(
        f"  pci:  {len(pci_tracks)}/{total_pci} hits "
        f"({_pct(len(pci_tracks), total_pci):.1f}%)"
    )

    pitches = _group_by_pitch(ball_tracks)
    console.print(f"\n[bold]Pitches detected[/bold]: {len(pitches)}")
    max_ys = _report_pitches(pitches, _group_by_pitch(pitch_preds), min_dy)
    _report_plate_calibration(max_ys, frame_height, min_dy)
    _report_pci_stability(pci_tracks)


def _pct(part: int, total: int) -> float:
    """Return part/total as a percentage; an empty total yields 0%."""
    return 100 * part / max(total, 1)


def _group_by_pitch(events: list[dict]) -> dict[int, list[dict]]:
    """Group *events* by integer ``pitch_id``; missing or null ids map to 0."""
    groups: dict[int, list[dict]] = defaultdict(list)
    for e in events:
        # `or 0` also covers an explicit JSON null, which int() would reject.
        groups[int(e.get("pitch_id") or 0)].append(e)
    return groups


def _report_pitches(
    pitches: dict[int, list[dict]],
    preds_per_pitch: dict[int, list[dict]],
    min_dy: float,
) -> list[float]:
    """Print one summary line per pitch with >= 2 points.

    Returns the max y of every printed pitch whose y span is >= *min_dy*.
    """
    max_ys: list[float] = []
    for pid in sorted(pitches):
        pts = pitches[pid]
        if len(pts) < 2:
            # A single point has no duration or span worth reporting.
            continue
        xs = np.array([p["x"] for p in pts])
        ys = np.array([p["y"] for p in pts])
        ts = np.array([p["ts_ns"] for p in pts])
        # max - min (not last - first) so events logged out of order cannot
        # produce a negative or truncated duration.
        dur_ms = (ts.max() - ts.min()) / 1e6
        max_y = float(ys.max())
        y_delta = float(ys.max() - ys.min())
        pred_n = len(preds_per_pitch.get(pid, []))
        console.print(
            f"  pitch #{pid:3d}  n={len(pts):3d}  dur={dur_ms:6.0f}ms  "
            f"x=({xs.min():.0f}..{xs.max():.0f})  y=({ys.min():.0f}..{ys.max():.0f})  "
            f"max_y={max_y:.0f}  dy={y_delta:.0f}  preds={pred_n}"
        )
        if y_delta >= min_dy:
            max_ys.append(max_y)
    return max_ys


def _report_plate_calibration(
    max_ys: list[float], frame_height: float, min_dy: float
) -> None:
    """Print max_y statistics and a suggested plate_y_frac (p90 / frame_height)."""
    console.print("\n[bold]plate_y calibration[/bold]")
    if not max_ys:
        console.print(
            f"  no pitches with >={min_dy:g}px y-delta - "
            "probably none captured or trajectories too short"
        )
        return
    arr = np.array(max_ys)
    p90 = float(np.percentile(arr, 90))
    console.print(f"  real pitches (dy>={min_dy:g}px): {len(arr)}")
    # p90 is printed because the suggestion below is derived from it.
    console.print(
        f"  max_y stats: mean={arr.mean():.1f}  p50={np.percentile(arr, 50):.1f}  "
        f"p90={p90:.1f}  p95={np.percentile(arr, 95):.1f}  max={arr.max():.1f}"
    )
    rec = p90 / frame_height
    # NOTE(review): 0.72 is presumably the previously configured value; confirm.
    console.print(f"  [green]suggested plate_y_frac ~= {rec:.3f}[/green]  (was 0.72)")


def _report_pci_stability(pci_tracks: list[dict]) -> None:
    """Print mean, std and range of PCI track x/y positions."""
    console.print("\n[bold]PCI centroid stability[/bold]")
    if not pci_tracks:
        console.print("  no PCI hits recorded")
        return
    xs = np.array([p["x"] for p in pci_tracks])
    ys = np.array([p["y"] for p in pci_tracks])
    console.print(f"  samples: {len(pci_tracks)}")
    console.print(f"  x: mean={xs.mean():.0f}  std={xs.std():.0f}  range=({xs.min():.0f}..{xs.max():.0f})")
    console.print(f"  y: mean={ys.mean():.0f}  std={ys.std():.0f}  range=({ys.min():.0f}..{ys.max():.0f})")
88
 
89
def main() -> int:
    """Command-line entry point: analyze the events JSONL named in argv[1].

    Returns:
        Process exit status: 0 on success, 2 on a usage error, 1 if the file
        cannot be opened or decoded.
    """
    if len(sys.argv) < 2:
        console.print("usage: analyze_session.py <events.jsonl>")
        return 2
    path_arg = sys.argv[1]
    try:
        events = load_events(Path(path_arg))
    except (OSError, UnicodeDecodeError) as err:
        # Report a missing/unreadable file as a normal CLI error, not a traceback.
        console.print(f"error: cannot read {path_arg}: {err}", markup=False)
        return 1
    # markup=False: a user-supplied path may contain "[...]" that rich would
    # otherwise interpret as style tags.
    console.print(f"loaded {len(events)} events from {path_arg}", markup=False)
    analyze(events)
    return 0
97
 
98
if __name__ == "__main__":
    # Hand main()'s integer status to the interpreter as the exit code.
    raise SystemExit(main())