Zion Boggan
repos/Pitch Tracker CV/cv/pci_tracker.py
zionboggan.com ↗
170 lines · python
History for this file →
1
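"""PCI tracker.

Finds the Plate Coverage Indicator (the green circle in MLB The Show Zone+PCI
batting). Uses template matching if the template image configured under
cv.pci.template (e.g. configs/templates/pci_circle.png) exists. Otherwise, and on
any frame where the template match scores below conf_threshold, it falls back to
HSV segmentation using the green range in runtime.yaml.

Emits a pci_track event (x, y, r, score, method) for each frame with a detection,
and a pci_miss event for each frame without one.
"""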
9
from __future__ import annotations
10
 
11
import argparse
12
import sys
13
import time
14
from pathlib import Path
15
 
16
import cv2
17
import numpy as np
18
from rich.console import Console
19
 
20
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
21
from cv._common import (
22
    iter_latest_frames,
23
    load_config,
24
    make_frame_subscriber,
25
    make_pub,
26
    send_event,
27
)
28
 
29
# Module-wide rich console used for the tracker's status, warning and summary output.
console = Console()
30
 
31
def load_template(path: Path) -> np.ndarray | None:
    """Load the PCI template image as a 3-channel colour array.

    Args:
        path: Location of the template image on disk.

    Returns:
        The decoded image (cv2.IMREAD_COLOR). Returns None in two cases:
        silently when ``path`` does not exist, or after printing a warning
        when OpenCV cannot decode the file. Either way the caller can fall
        back to HSV detection.
    """
    image = None
    if path.exists():
        image = cv2.imread(str(path), cv2.IMREAD_COLOR)
        if image is None:
            console.print(f"[yellow]PCI template at {path} could not be read.[/yellow]")
    return image
39
 
40
def detect_by_template(frame: np.ndarray, tpl: np.ndarray, threshold: float):
    """Locate the PCI in ``frame`` by normalized cross-correlation template matching.

    Args:
        frame: Image to search. It must share ``tpl``'s channel layout and dtype,
            as cv2.matchTemplate requires.
        tpl: Template image.
        threshold: Minimum TM_CCOEFF_NORMED peak score accepted as a match.

    Returns:
        A detection dict with these keys:

        - x, y: template centre in frame pixels.
        - r: half the template's shorter side.
        - score: peak correlation.
        - method: "template".

        Returns None when the template does not fit inside the frame, when the
        peak score is not finite, or when it is below ``threshold``.
    """
    th, tw = tpl.shape[:2]
    fh, fw = frame.shape[:2]
    # cv2.matchTemplate raises when the template is larger than the image.
    # Report "no match" instead so the caller falls back to HSV rather than
    # crashing the tracker loop.
    if th > fh or tw > fw:
        return None

    res = cv2.matchTemplate(frame, tpl, cv2.TM_CCOEFF_NORMED)
    _, max_val, _, max_loc = cv2.minMaxLoc(res)
    # A non-finite peak is never a usable match; reject it along with weak ones.
    if not np.isfinite(max_val) or max_val < threshold:
        return None

    # max_loc is the matched window's top-left corner; report its centre.
    x = max_loc[0] + tw / 2.0
    y = max_loc[1] + th / 2.0
    r = 0.5 * min(tw, th)
    return {"x": float(x), "y": float(y), "r": float(r), "score": float(max_val), "method": "template"}
50
 
51
def detect_by_hsv(frame: np.ndarray, cfg_pci: dict):
    """Centroid-based PCI detection.

    The PCI in MLB 26 renders as a multi-part shape (brackets + inner + center),
    so fitting a single contour circle doesn't work reliably. Instead we mask
    the configured green range, restrict it to the central strike-zone region,
    and return the centroid of all matching pixels as the PCI position.
    Radius is estimated from the spread of matching pixels.

    Args:
        frame: Image converted with cv2.COLOR_BGR2HSV, so BGR channel order
            is expected.
        cfg_pci: PCI config section.

            - Required keys: hsv_low and hsv_high (HSV bounds for cv2.inRange).
            - Optional min_green_pixels: default 150.
            - Optional search_x_min_frac / search_x_max_frac: defaults 0.3 / 0.7.
            - Optional search_y_min_frac / search_y_max_frac: defaults 0.3 / 0.8.

            The search fractions are relative to frame width/height and are
            clamped to [0, 1].

    Returns:
        None when the search window is empty, or when fewer than
        max(min_green_pixels, 1) pixels survive masking. Otherwise a dict with:

        - x, y: centroid in frame pixels.
        - r: 80th-percentile distance of matching pixels from the centroid.
        - score: matching pixels divided by their bounding-box area.
        - method: "hsv_centroid".
        - n_px: the matching-pixel count.
    """
    hsv_low = np.array(cfg_pci["hsv_low"], dtype=np.uint8)
    hsv_high = np.array(cfg_pci["hsv_high"], dtype=np.uint8)
    min_px = int(cfg_pci.get("min_green_pixels", 150))

    def frac(key: str, default: float) -> float:
        # Clamp to [0, 1]. A negative fraction would otherwise become a negative
        # slice index and silently wrap the search window to the far edge.
        return min(max(float(cfg_pci.get(key, default)), 0.0), 1.0)

    h, w = frame.shape[:2]
    x0 = int(w * frac("search_x_min_frac", 0.3))
    x1 = int(w * frac("search_x_max_frac", 0.7))
    y0 = int(h * frac("search_y_min_frac", 0.3))
    y1 = int(h * frac("search_y_max_frac", 0.8))
    if x0 >= x1 or y0 >= y1:
        # Empty search window: nothing can match, so skip the colour conversion.
        return None

    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    mask = cv2.inRange(hsv, hsv_low, hsv_high)

    # Keep only in-range pixels that fall inside the search window.
    window = np.zeros_like(mask)
    window[y0:y1, x0:x1] = 255
    mask = cv2.bitwise_and(mask, window)
    # A 3x3 morphological opening drops isolated speckle pixels before the centroid.
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=1)

    ys, xs = np.where(mask > 0)
    n = len(xs)
    # Check n == 0 explicitly. Otherwise a min_green_pixels of 0 (or less) would
    # send an empty selection into mean()/min() below.
    if n == 0 or n < min_px:
        return None

    cx = float(xs.mean())
    cy = float(ys.mean())

    dx = xs - cx
    dy = ys - cy
    dists = np.sqrt(dx * dx + dy * dy)
    # Use a percentile rather than the max distance so a few stray pixels
    # do not dominate the radius estimate.
    r = float(np.percentile(dists, 80))

    # Score is how densely the matching pixels fill their bounding box.
    x_bb_min, x_bb_max = int(xs.min()), int(xs.max())
    y_bb_min, y_bb_max = int(ys.min()), int(ys.max())
    bbox_area = max(1, (x_bb_max - x_bb_min + 1) * (y_bb_max - y_bb_min + 1))
    score = float(n / bbox_area)

    return {"x": cx, "y": cy, "r": r, "score": score, "method": "hsv_centroid", "n_px": n}
99
 
100
def _detect_pci(frame: np.ndarray, tpl: np.ndarray | None, tpl_thresh: float, pci_cfg: dict):
    """Detect the PCI in one frame: template match first (if loaded), else HSV."""
    det = None
    if tpl is not None:
        det = detect_by_template(frame, tpl, tpl_thresh)
    if det is None:
        det = detect_by_hsv(frame, pci_cfg)
    return det


def main() -> int:
    """Run the PCI tracker: subscribe to frames, detect the PCI, publish events.

    Reads capture.publish_endpoint, cv.pci_events_endpoint and cv.pci from the
    loaded config. For every received frame it publishes one of two events:
    a ``pci_track`` event (type, seq, ts_ns plus the detection fields) when
    the PCI is found, or a ``pci_miss`` event (type, seq, ts_ns) otherwise.

    Returns:
        0 when the loop ends normally: --duration elapsed, the frame iterator
        stopped, or Ctrl-C. 2 when iter_latest_frames (called with
        timeout_ms=3000) raises TimeoutError.
    """
    ap = argparse.ArgumentParser(description="PCI tracker.")
    ap.add_argument("--duration", type=float, default=0.0, help="Stop after N seconds (0 = run forever).")
    ap.add_argument("--quiet", action="store_true")
    args = ap.parse_args()

    cfg = load_config()
    capture_ep = cfg["capture"]["publish_endpoint"]
    pci_ep = cfg["cv"]["pci_events_endpoint"]
    pci_cfg = cfg["cv"]["pci"]

    # An absent/empty "template" entry means "no template": use the HSV
    # detector instead of failing with KeyError.
    tpl_rel = pci_cfg.get("template")
    tpl_path = Path(__file__).resolve().parents[1] / tpl_rel if tpl_rel else None
    tpl = load_template(tpl_path) if tpl_path is not None else None
    tpl_thresh = float(pci_cfg.get("conf_threshold", 0.7))

    sub = make_frame_subscriber(capture_ep)
    try:
        pub = make_pub(pci_ep)
    except BaseException:
        # Close the subscriber so it does not leak if the publisher cannot be created.
        sub.close(0)
        raise

    method_label = "template" if tpl is not None else "hsv"
    if tpl is not None:
        fallback_note = ""
    elif tpl_path is not None:
        fallback_note = f"  (no template at {tpl_path.name}; using HSV fallback)"
    else:
        fallback_note = "  (no template configured; using HSV fallback)"
    console.print(
        f"[green]pci_tracker[/green] sub={capture_ep} pub={pci_ep}  method={method_label}"
        + fallback_note
    )

    t_end = time.perf_counter() + args.duration if args.duration > 0 else None
    frames = 0
    hits = 0
    t_report = time.perf_counter()

    try:
        for meta, frame in iter_latest_frames(sub, timeout_ms=3000):
            if t_end is not None and time.perf_counter() >= t_end:
                break
            frames += 1
            det = _detect_pci(frame, tpl, tpl_thresh, pci_cfg)
            header = {"seq": int(meta["seq"]), "ts_ns": int(meta["ts_ns"])}
            if det is not None:
                hits += 1
                send_event(pub, {"type": "pci_track", **header, **det})
            else:
                send_event(pub, {"type": "pci_miss", **header})

            # Periodic progress line (every >= 5 s) unless --quiet.
            now = time.perf_counter()
            if not args.quiet and now - t_report >= 5.0:
                console.print(f"[dim]  pci:  {frames} frames, {hits} hits ({method_label})[/dim]")
                t_report = now
    except TimeoutError as e:
        console.print(f"[red]pci_tracker: {e}. Is capture/ingest.py running?[/red]")
        return 2
    except KeyboardInterrupt:
        console.print("[yellow]pci_tracker interrupted.[/yellow]")
    finally:
        sub.close(0)
        pub.close(0)
    console.print(f"[bold]pci_tracker summary:[/bold] frames={frames} hits={hits} method={method_label}")
    return 0
168
 
169
if __name__ == "__main__":
    # Raising SystemExit(code) is exactly what sys.exit(code) does, so main()'s
    # return value becomes the process exit status.
    raise SystemExit(main())