Zion Boggan
repos/Pitch Tracker CV/cv/ball_tracker.py
zionboggan.com ↗
278 lines · python
History for this file →
1
"""Ball tracker.

Subscribes to capture frames, finds the ball via HSV thresholding + a
circularity score (contour area / enclosing-circle area), keeps a rolling
window of detections, and emits:
  - ball_track events for each accepted detection (the most circular
    ball-coloured blob in the frame)
  - ball_miss events when no ball is found or a detection is rejected
    (reason: banned_zone, not_pitch_start or static_ui_banned)
  - pitch_pred events once a fittable trajectory accumulates (plate_x + eta_ms)

Classical CV only. Tune HSV and radius bounds via configs/runtime.yaml.
"""
10
from __future__ import annotations
11
 
12
import argparse
13
import sys
14
import time
15
from collections import deque
16
from pathlib import Path
17
 
18
import cv2
19
import numpy as np
20
from rich.console import Console
21
 
22
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
23
from cv._common import (
24
    event_subscriber,
25
    iter_latest_frames,
26
    load_config,
27
    make_frame_subscriber,
28
    make_pub,
29
    send_event,
30
)
31
 
32
# Shared rich console for this script's startup, status and summary output.
console = Console()
33
 
34
# --- Trajectory window and fitting -------------------------------------------
TRAJ_WINDOW_S = 0.8      # drop trail points older than this (s) relative to the newest
MIN_FIT_POINTS = 4       # no fit attempted with fewer trail points than this
FIT_USE_LAST_N = 8       # only the newest N points feed the polynomial fit
MIN_DOWN_PX = 30         # fitted points must span at least this many px vertically

# --- Pitch segmentation ------------------------------------------------------
PITCH_GAP_MS = 300       # a detection gap longer than this closes the current pitch
MAX_STEP_PX = 200        # a jump larger than this between detections closes the pitch
PITCH_START_MAX_Y = 500  # a new trail may only start at y <= this (px from top)

# --- Static-object suppression -----------------------------------------------
STATIC_WINDOW = 3        # number of newest detections checked for being stationary
STATIC_SPREAD_PX = 3.0   # x/y spread below this marks those detections as static
BAN_ZONE_MS = 1500       # how long a static location stays banned
BAN_ZONE_PX = 6          # half-width (px, per axis) of a banned box
46
 
47
class Detection:
    """One ball candidate found in a single frame.

    Attributes:
        ts_ns: frame timestamp in nanoseconds. ``detect_ball`` creates
            detections with 0 here; the caller fills in the frame's timestamp.
        x, y: centre of the enclosing circle, in pixels.
        r: radius of the enclosing circle, in pixels.
        score: circularity = contour area / enclosing-circle area.
    """

    # Many instances are created per second; slots avoid a per-instance dict.
    __slots__ = ("ts_ns", "x", "y", "r", "score")

    def __init__(self, ts_ns: int, x: float, y: float, r: float, score: float):
        self.ts_ns = ts_ns
        self.x = x
        self.y = y
        self.r = r
        self.score = score

    def __repr__(self) -> str:
        # Readable form for logging / debugging trails.
        return (
            f"{type(self).__name__}(ts_ns={self.ts_ns!r}, x={self.x!r}, "
            f"y={self.y!r}, r={self.r!r}, score={self.score!r})"
        )
55
 
56
def detect_ball(frame: np.ndarray, cfg: dict) -> Detection | None:
    """Find the most circular ball-coloured blob in a BGR frame.

    Pixels inside ``cfg["cv"]["pitch_detect"]``'s HSV range are masked and
    cleaned with a 3x3 morphological opening. Each external contour is then
    screened by minimum area (half a circle of ``min_ball_radius_px``),
    enclosing-circle radius bounds and a circularity of at least 0.6.

    Returns:
        The candidate with the highest circularity as a ``Detection`` whose
        ``ts_ns`` is 0 (the caller stamps it), or None if nothing qualifies.
    """
    params = cfg["cv"]["pitch_detect"]
    lower = np.array(params["ball_hsv_low"], dtype=np.uint8)
    upper = np.array(params["ball_hsv_high"], dtype=np.uint8)
    min_r = float(params["min_ball_radius_px"])
    max_r = float(params["max_ball_radius_px"])

    kernel = np.ones((3, 3), np.uint8)
    mask = cv2.inRange(cv2.cvtColor(frame, cv2.COLOR_BGR2HSV), lower, upper)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Blobs smaller than half a minimum-radius circle are noise.
    min_area = np.pi * min_r * min_r * 0.5

    winner: Detection | None = None
    top = 0.0
    for contour in contours:
        area = cv2.contourArea(contour)
        if area < min_area:
            continue
        (cx, cy), radius = cv2.minEnclosingCircle(contour)
        if radius < min_r or radius > max_r:
            continue
        enclosing = np.pi * radius * radius
        circularity = float(area / enclosing) if enclosing > 0 else 0.0
        if circularity >= 0.6 and circularity > top:
            top = circularity
            winner = Detection(
                ts_ns=0, x=float(cx), y=float(cy), r=float(radius), score=circularity
            )
    return winner
86
 
87
def _downward_crossing_time(a: float, b: float, c: float, t_after: float) -> float | None:
    """Return the earliest t > t_after where a*t**2 + b*t + c hits 0 while increasing, else None."""
    if abs(a) < 1e-6:
        # Effectively constant-velocity motion: solve b*t + c = 0 directly
        # instead of giving up (the quadratic formula is degenerate here).
        if b == 0:
            return None
        roots = [-c / b]
    else:
        disc = b * b - 4.0 * a * c
        if disc < 0:
            return None
        # Numerically stable quadratic formula: avoids the cancellation in
        # (-b + sqrt(disc)) when |4ac| << b**2 (nearly linear motion).
        q = -0.5 * (b + float(np.copysign(np.sqrt(disc), b)))
        roots = [q / a]
        if q != 0.0:
            roots.append(c / q)
    # Only a crossing with dy/dt > 0 (moving down the image, toward the plate)
    # is a real arrival; later roots where the parabola turns back are spurious.
    future = [t for t in roots if t > t_after and 2.0 * a * t + b > 0.0]
    return min(future) if future else None


def try_fit(
    trail: deque[Detection],
    plate_y_px: float,
    *,
    now_ns: int | None = None,
    max_rise_px: float = 30.0,
    max_eta_ms: float = 2000.0,
) -> tuple[float, float] | None:
    """Predict where and when the ball crosses the plate line.

    Fits y(t) with a quadratic and x(t) with a line over the newest
    FIT_USE_LAST_N detections. It then solves for the first future time at
    which the fitted y reaches ``plate_y_px`` while still increasing (moving
    down the image).

    Args:
        trail: detections, oldest first; only ``ts_ns``, ``x`` and ``y`` are read.
        plate_y_px: image row (px) of the plate line.
        now_ns: reference time for the ETA, on the same clock as ``ts_ns``.
            Defaults to ``time.time_ns()``, which is only meaningful if frame
            timestamps are wall-clock epoch nanoseconds.
        max_rise_px: reject the fit if y decreased (ball rose) by more than
            this across the fitted window.
        max_eta_ms: reject predictions further than this in the future.

    Returns:
        ``(plate_x_px, eta_ms_from_now)``, or None when the trail is too short,
        spans fewer than MIN_DOWN_PX vertically, is moving up, never reaches
        the plate line going down, or the ETA is outside ``[0, max_eta_ms]``.
    """
    if len(trail) < MIN_FIT_POINTS:
        return None

    recent = list(trail)[-FIT_USE_LAST_N:]
    ys = np.array([d.y for d in recent], dtype=np.float64)
    if ys.max() - ys.min() < MIN_DOWN_PX:
        return None
    if ys[-1] - ys[0] < -max_rise_px:
        return None

    t0 = recent[0].ts_ns
    ts = np.array([(d.ts_ns - t0) / 1e9 for d in recent], dtype=np.float64)
    xs = np.array([d.x for d in recent], dtype=np.float64)

    ay, by, cy = np.polyfit(ts, ys, 2)
    bx, cx = np.polyfit(ts, xs, 1)

    t_cross = _downward_crossing_time(
        float(ay), float(by), float(cy) - plate_y_px, float(ts[-1])
    )
    if t_cross is None:
        return None

    if now_ns is None:
        now_ns = time.time_ns()
    plate_ns = t0 + int(t_cross * 1e9)
    eta_ms = (plate_ns - now_ns) / 1e6
    if eta_ms < 0 or eta_ms > max_eta_ms:
        return None
    plate_x = bx * t_cross + cx
    return float(plate_x), float(eta_ms)
126
 
127
def _in_banned_zone(det: Detection, banned: deque) -> bool:
    """Return True if *det* lies within BAN_ZONE_PX (on both axes) of any banned centre."""
    return any(
        abs(det.x - bx) < BAN_ZONE_PX and abs(det.y - by) < BAN_ZONE_PX
        for bx, by, _expiry_ns in banned
    )


def _static_centroid(trail: deque[Detection]) -> tuple[float, float] | None:
    """Return the mean (x, y) of the newest STATIC_WINDOW detections if they barely move, else None."""
    if len(trail) < STATIC_WINDOW:
        return None
    recent = list(trail)[-STATIC_WINDOW:]
    rxs = [d.x for d in recent]
    rys = [d.y for d in recent]
    spread = max(max(rxs) - min(rxs), max(rys) - min(rys))
    if spread < STATIC_SPREAD_PX:
        return sum(rxs) / len(rxs), sum(rys) / len(rys)
    return None


def main() -> int:
    """Run the ball tracker over the capture stream and publish ball events.

    Each frame produces one of the following events:
      - ``ball_miss``, with no reason when no ball is found, or with reason
        ``banned_zone``, ``not_pitch_start`` or ``static_ui_banned``;
      - ``ball_track``, followed by a ``pitch_pred`` whenever ``try_fit``
        succeeds.

    Returns:
        0 when the loop ends normally (stream exhausted, --duration elapsed,
        or Ctrl-C), 2 when a TimeoutError escapes the loop. The TimeoutError
        presumably comes from ``iter_latest_frames`` when no frame arrives
        within 3 s; confirm against cv._common.
    """
    ap = argparse.ArgumentParser(description="Ball tracker + parabolic pitch-prediction.")
    ap.add_argument("--duration", type=float, default=0.0, help="Stop after N seconds (0 = run forever).")
    ap.add_argument("--quiet", action="store_true", help="Suppress per-frame status lines.")
    args = ap.parse_args()

    cfg = load_config()
    capture_ep = cfg["capture"]["publish_endpoint"]
    ball_ep = cfg["cv"]["ball_events_endpoint"]
    plate_y_frac = float(cfg["cv"].get("plate_y_frac", 0.72))

    sub = make_frame_subscriber(capture_ep)
    pub = make_pub(ball_ep)
    console.print(f"[green]ball_tracker[/green] sub={capture_ep} pub={ball_ep}")

    trail: deque[Detection] = deque(maxlen=64)
    # (x, y, expiry_ts_ns) of static blobs whose neighbourhood is ignored.
    banned: deque[tuple[float, float, int]] = deque(maxlen=32)
    pitch_id = 0
    t_end = time.perf_counter() + args.duration if args.duration > 0 else None
    frames = 0
    hits = 0
    preds = 0
    t_report = time.perf_counter()

    try:
        for meta, frame in iter_latest_frames(sub, timeout_ms=3000):
            if t_end is not None and time.perf_counter() >= t_end:
                break
            frames += 1

            # Periodic status line. It is checked before any per-frame `continue`
            # so that long runs of rejected detections cannot suppress it.
            now = time.perf_counter()
            if not args.quiet and now - t_report >= 5.0:
                console.print(f"[dim]  ball: {frames} frames, {hits} hits, {preds} preds, trail={len(trail)}[/dim]")
                t_report = now

            seq = int(meta["seq"])
            ts_ns = int(meta["ts_ns"])
            h = int(meta.get("h", frame.shape[0]))
            plate_y_px = h * plate_y_frac

            # Expire bans. Expiries are appended as ts_ns + constant, so they are
            # ordered as long as ts_ns is non-decreasing, and only the head
            # needs checking.
            while banned and banned[0][2] <= ts_ns:
                banned.popleft()

            det = detect_ball(frame, cfg)
            if det is None:
                send_event(pub, {
                    "type": "ball_miss",
                    "seq": seq,
                    "ts_ns": ts_ns,
                })
                continue

            if _in_banned_zone(det, banned):
                send_event(pub, {
                    "type": "ball_miss",
                    "seq": seq,
                    "ts_ns": ts_ns,
                    "reason": "banned_zone",
                })
                continue

            det.ts_ns = ts_ns

            # Pitch segmentation: a long time gap or a large spatial jump from
            # the previous accepted detection closes the current pitch. pitch_id
            # only advances when a non-empty trail is actually closed, so runs of
            # rejected frames after a gap do not inflate it.
            if trail:
                last_d = trail[-1]
                dx = det.x - last_d.x
                dy = det.y - last_d.y
                gap = (ts_ns - last_d.ts_ns) > PITCH_GAP_MS * 1e6
                jump = (dx * dx + dy * dy) ** 0.5 > MAX_STEP_PX
                if gap or jump:
                    trail.clear()
                    pitch_id += 1

            # A new trail must start high enough in the frame (y <= PITCH_START_MAX_Y).
            if not trail and det.y > PITCH_START_MAX_Y:
                send_event(pub, {
                    "type": "ball_miss",
                    "seq": seq,
                    "ts_ns": ts_ns,
                    "reason": "not_pitch_start",
                    "det_y": det.y,
                })
                continue

            trail.append(det)

            # Keep only detections inside the trajectory window. The newest
            # one always survives because its ts_ns is >= cutoff.
            cutoff = ts_ns - int(TRAJ_WINDOW_S * 1e9)
            while trail and trail[0].ts_ns < cutoff:
                trail.popleft()

            # A blob that stops moving is treated as a static UI element.
            # Ban its location for BAN_ZONE_MS and restart the trail.
            centroid = _static_centroid(trail)
            if centroid is not None:
                cx, cy = centroid
                banned.append((cx, cy, ts_ns + int(BAN_ZONE_MS * 1e6)))
                trail.clear()
                send_event(pub, {
                    "type": "ball_miss",
                    "seq": seq,
                    "ts_ns": ts_ns,
                    "reason": "static_ui_banned",
                    "banned_x": cx, "banned_y": cy,
                })
                continue

            hits += 1
            send_event(pub, {
                "type": "ball_track",
                "seq": seq,
                "ts_ns": ts_ns,
                "pitch_id": pitch_id,
                "x": det.x, "y": det.y, "r": det.r,
                "score": det.score,
            })

            fit = try_fit(trail, plate_y_px)
            if fit is not None:
                plate_x, eta_ms = fit
                preds += 1
                send_event(pub, {
                    "type": "pitch_pred",
                    "seq": seq,
                    "ts_ns": ts_ns,
                    "pitch_id": pitch_id,
                    "plate_x": plate_x,
                    "plate_y": plate_y_px,
                    "eta_ms": eta_ms,
                    "n_points": len(trail),
                })
    except TimeoutError as e:
        console.print(f"[red]ball_tracker: {e}. Is capture/ingest.py running?[/red]")
        return 2
    except KeyboardInterrupt:
        console.print("[yellow]ball_tracker interrupted.[/yellow]")
    finally:
        sub.close(0)
        pub.close(0)
    console.print(f"[bold]ball_tracker summary:[/bold] frames={frames} hits={hits} preds={preds}")
    return 0
276
 
277
if __name__ == "__main__":
    # Use main()'s return value (0 or 2) as the process exit status.
    raise SystemExit(main())