| @@ -0,0 +1,68 @@ | ||
| + | # Accessibility statement | |
| + | ||
| + | ## Why this exists | |
| + | ||
| + | MLB The Show's Zone hitting interface requires the player to position a small | |
| + | on-screen circle (the PCI) inside an even smaller strike zone with the left | |
| + | analog stick, and to time a swing button press to within a few frames of the | |
| + | ball crossing the plate. Pinpoint pitching requires drawing precise stick | |
| + | gestures, sometimes in under a second. | |
| + | ||
| + | For players with motor disabilities affecting fine motor control or reaction | |
| + | time, these inputs are not achievable, even on the easiest difficulty. | |
| + | Difficulty settings reduce CPU skill; they do not reduce the precision | |
| + | required from the human player. | |
| + | ||
| + | This project is an assistive input layer that performs the precision input on | |
| + | the player's behalf, in modes where there is no human opponent and where the | |
| + | only effect of the assistance is to let the player participate at all. | |
| + | ||
| + | ## Scope | |
| + | ||
| + | Supported (offline, no human opponent): | |
| + | ||
| + | - Diamond Dynasty vs CPU | |
| + | - Conquest | |
| + | - Moments | |
| + | - Showdown (offline) | |
| + | - Road to the Show | |
| + | - March to October | |
| + | - Franchise | |
| + | - Custom League (offline) | |
| + | - Practice / batting practice | |
| + | ||
| + | Not supported, and actively blocked by the safety code: | |
| + | ||
| + | - Diamond Dynasty Ranked / Events / Co-op | |
| + | - Online Head-to-Head | |
| + | - Any mode where a non-consenting human is the opposing player | |
| + | ||
| + | ## Safety controls | |
| + | ||
| + | The pipeline disarms automatically when: | |
| + | ||
| + | 1. An online-mode UI element is detected on screen. | |
| + | 2. The capture stream stalls for longer than `abort_on_capture_loss_ms`. | |
| + | 3. The game leaves an active play (menu/cutscene detection). | |
| + | 4. The user presses the deadman hotkey (default `F12`). | |
| + | ||
| + | The Titan Two GPC script passes raw controller input through at all times. | |
| + | Aim assist is additive, clamped, and gated by the safety checks above. | |
| + | ||
| + | ## What this project is not | |
| + | ||
| + | - Not a cheat. It does not give the user any advantage that the game itself | |
| + | does not already provide. The CPU's pitch placement, count, and difficulty | |
| + | scaling are unchanged. | |
| + | - Not anti-cheat evasion. There is no code that attempts to hide its presence | |
| + | from any anti-cheat system, and the system refuses to operate against one. | |
| + | - Not network spoofing. The pipeline does not touch the network stack. | |
| + | - Not a general-purpose aimbot. It is specifically engineered for one game's | |
| + | Zone-hitting and Pinpoint-pitching mechanics, and refuses to run otherwise. | |
| + | ||
| + | ## License coupling | |
| + | ||
| + | These scope restrictions are not aspirational. They are written into the | |
| + | [LICENSE](LICENSE) as an additional condition. Removing the online-detect | |
| + | safety code, or running the system in any online context, terminates the | |
| + | grant. |
| @@ -0,0 +1,30 @@ | ||
| + | MIT License (with No-Online-Use additional clause) | |
| + | ||
| + | Copyright (c) 2026 | |
| + | ||
| + | Permission is hereby granted, free of charge, to any person obtaining a copy | |
| + | of this software and associated documentation files (the "Software"), to deal | |
| + | in the Software without restriction, including without limitation the rights | |
| + | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| + | copies of the Software, and to permit persons to whom the Software is | |
| + | furnished to do so, subject to the following conditions: | |
| + | ||
| + | The above copyright notice and this permission notice shall be included in | |
| + | all copies or substantial portions of the Software. | |
| + | ||
| + | ADDITIONAL CONDITION: Use of this Software, or any derivative work, in any | |
| + | online or multiplayer game mode, ranked match, competitive event, or in any | |
| + | context where a non-consenting human is the opposing player, is expressly | |
| + | forbidden. This Software is provided for single-player and vs-CPU use only, | |
| + | and only as an accessibility aid for players who cannot otherwise execute | |
| + | the precise inputs required by the underlying game. Removal of, or | |
| + | circumvention of, any online-detection safety code in the Software | |
| + | terminates the rights granted under this license. | |
| + | ||
| + | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| + | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| + | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| + | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| + | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| + | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | |
| + | THE SOFTWARE. |
| @@ -0,0 +1,34 @@ | ||
| + | # NOTICE | |
| + | ||
| + | This project is an accessibility-focused assistive input layer for a single | |
| + | specific console baseball game, intended only for use in offline / vs-CPU | |
| + | modes by players with motor disabilities. | |
| + | ||
| + | ## Trademarks | |
| + | ||
| + | "MLB The Show" is a trademark of MLB Advanced Media, L.P. and Sony Interactive | |
| + | Entertainment LLC. "Xbox" is a trademark of Microsoft Corporation. "Titan Two" | |
| + | and "Gtuner" are trademarks of Console Tuner. All trademarks are the property | |
| + | of their respective owners. This project is not affiliated with, endorsed by, | |
| + | sponsored by, or otherwise associated with any of these entities. | |
| + | ||
| + | ## Game integrity | |
| + | ||
| + | This project does not: | |
| + | ||
| + | - Modify any game binary or game asset. | |
| + | - Inject code into the game process. | |
| + | - Intercept, modify, or spoof any network traffic. | |
| + | - Read or write any game memory. | |
| + | - Bypass, disable, or evade any anti-cheat mechanism. | |
| + | - Operate in any online or multiplayer mode. | |
| + | ||
| + | It observes the game's video output via a standard USB capture card, and | |
| + | emits standard controller HID input via a commercially available controller | |
| + | adapter, exactly as a sighted, able-bodied player would do by hand. Both are | |
| + | operations that the user is already authorized to perform. | |
| + | ||
| + | ## Reporting | |
| + | ||
| + | If you are a rights-holder and believe this project requires takedown or | |
| + | modification, please open an issue on the project's repository. |
| @@ -0,0 +1,121 @@ | ||
| + | # pitch-tracker-cv | |
| + | ||
| + | A real-time computer-vision aim-assist for **offline, single-player** MLB The Show 26. | |
| + | Built as an accessibility aid for players with motor disabilities who cannot | |
| + | reliably execute the small left-stick corrections that the game's Zone hitting | |
| + | interface demands, or the precise gesture inputs that Pinpoint pitching requires. | |
| + | ||
| + | The system reads the game through a capture card, runs a vision pipeline that | |
| + | tracks the ball and the PCI (Plate Coverage Indicator), predicts where the | |
| + | pitch will cross the plate, and drives a Titan Two adapter to nudge the | |
| + | controller stick so that the user can play vs-CPU modes they otherwise could | |
| + | not. | |
| + | ||
| + | This is not a cheat. The tool only operates in offline modes (Diamond Dynasty | |
| + | vs CPU, Conquest, Moments, Showdown offline, Road to the Show, March to October, | |
| + | Franchise, offline Custom League, and practice modes). It refuses to arm if any | |
| + | online UI element is on screen. See | |
| + | [ACCESSIBILITY_STATEMENT.md](ACCESSIBILITY_STATEMENT.md) and [NOTICE.md](NOTICE.md). | |
| + | ||
| + | ## What it is, in one sentence | |
| + | ||
| + | A 1080p60 capture-card video stream goes in, controller stick deflections come | |
| + | out, with a YOLO-trained ball detector and classical CV PCI tracker in between. | |
| + | ||
| + | ## Architecture | |
| + | ||
| + | ``` | |
| + | +----------------+ +-----------+ +--------------+ +-----------+ | |
| + | | Capture card | USB3 | ingest.py | ZMQ | ball_tracker | ZMQ | bridge.py | | |
| + | | (1920x1080@60) +------->| +----->| +----->| | | |
| + | +----------------+ +-----------+ | pci_tracker | +-----+-----+ | |
| + | +--------------+ | | |
| + | v | |
| + | +-----------------+ | |
| + | | Titan Two (GPC) | | |
| + | +--------+--------+ | |
| + | | | |
| + | v | |
| + | +-----------------+ | |
| + | | Xbox controller | | |
| + | +-----------------+ | |
| + | ``` | |
| + | ||
| + | Four independent PC-side processes connected by ZMQ on localhost, plus one GPC script that runs on the Titan Two: | |
| + | ||
| + | 1. **`capture/ingest.py`** opens the capture card via OpenCV/DirectShow and | |
| + | publishes raw frames as multipart `[meta_json, jpeg]` messages on | |
| + | `tcp://127.0.0.1:5555`. Measured at 59.88 FPS @ 1920x1080 with mean read | |
| + | latency 16.3 ms (p95 17.4 ms) on the reference hardware. | |
| + | 2. **`cv/ball_tracker.py`** subscribes to capture frames, applies an HSV + | |
| + | circularity filter (or an ONNX YOLO detector for harder lighting), fits a | |
| + | 2D parabolic trajectory across a rolling 0.8 s window, and emits | |
| + | `pitch_pred` events with predicted plate-crossing `(x, eta_ms)`. | |
| + | 3. **`cv/pci_tracker.py`** template-matches or HSV-segments the PCI circle | |
| + | and emits `pci_track` events with the current PCI position. | |
| + | 4. **`io_titan/bridge.py`** subscribes to both tracker streams, runs the | |
| + | decision logic (aim error → stick deflection, residual error + count → | |
| + | swing decision), packs a 20-byte fixed-layout command frame, and writes | |
| + | it to the Titan Two via Gtuner IV's GCV (Game Computer Vision) interface. | |
| + | 5. **`io_titan/mlb26_bridge.gpc`** is the GPC script that runs on the Titan | |
| + | Two itself, reads the command frame from GCV memory, and applies the | |
| + | stick / button input to the Xbox controller passthrough. | |
| + | ||
| + | ## Hardware required | |
| + | ||
| + | - Xbox Series X/S | |
| + | - Monster 4K USB 3.0 capture card (or any 1080p60 DirectShow-compatible card) | |
| + | - Titan Two adapter with Gtuner IV installed | |
| + | - Wired Xbox controller | |
| + | - Windows PC for the vision pipeline (Python 3.11+) | |
| + | ||
| + | ## Quick start | |
| + | ||
| + | ```bash | |
| + | python -m venv .venv | |
| + | .venv\Scripts\activate | |
| + | pip install -r requirements.txt | |
| + | ||
| + | python -m capture.diagnose | |
| + | python -m capture.ingest --smoke | |
| + | ||
| + | python -m cv.ball_tracker | |
| + | python -m cv.pci_tracker | |
| + | python -m tools.viewer | |
| + | ``` | |
| + | ||
| + | YOLO ball detection (optional, more robust under HDR tone-mapping): | |
| + | ||
| + | ```bash | |
| + | pip install -r requirements-yolo.txt | |
| + | python -m tools.yolo_train_ball --data configs/ball_yolo.yaml | |
| + | python -m io_titan.mlb26_gcv_yolo | |
| + | ``` | |
| + | ||
| + | ## Project layout | |
| + | ||
| + | ``` | |
| + | capture/ capture card ingest, device probe | |
| + | cv/ ball tracker (HSV + parabolic fit), PCI tracker | |
| + | io_titan/ Python + GPC bridge to Titan Two | |
| + | configs/ runtime tunables, YOLO data config, templates | |
| + | tools/ one-shot utilities: viewer, snapshot, hsv_probe, | |
| + | detect_on_frame, YOLO frame collection/labeling/training | |
| + | docs/ architecture notes, ball-detector reference | |
| + | ``` | |
| + | ||
| + | ## Safety model | |
| + | ||
| + | The pipeline is designed to fail safe: | |
| + | ||
| + | - `safety.online_mode_abort` in `configs/runtime.yaml` instructs the bridge | |
| + | to disarm if any online-mode UI element is detected. | |
| + | - `safety.abort_on_menu_detected` disarms when the game leaves play. | |
| + | - `safety.abort_on_capture_loss_ms` disarms if the capture stream stalls. | |
| + | - A hardware deadman hotkey (default `F12`) forces full passthrough. | |
| + | - The Titan Two script always passes raw controller input through. Aim assist | |
| + | is additive only and clamped to small stick deflections. | |
| + | ||
| + | ## License | |
| + | ||
| + | MIT, with an explicit additional clause forbidding use in any online or | |
| + | multiplayer context. See [LICENSE](LICENSE). |
| @@ -0,0 +1,105 @@ | ||
| + | """Per-device diagnostic: probes each video index, reports real FOURCC/size/brightness.""" | |
| + | from __future__ import annotations | |
| + | ||
| + | import sys | |
| + | import time | |
| + | from pathlib import Path | |
| + | ||
| + | import cv2 | |
| + | import numpy as np | |
| + | from rich.console import Console | |
| + | ||
| + | console = Console() | |
| + | LOG_DIR = Path(__file__).resolve().parents[1] / "logs" | |
| + | ||
def fourcc_to_str(val: float) -> str:
    """Decode a numeric FOURCC property value into its four-character tag.

    The value is truncated to an int and its four low bytes are read lowest
    byte first, one character per byte. Non-positive values yield ``"NONE"``.
    """
    code = int(val)
    if code <= 0:
        return "NONE"
    # latin-1 maps every byte value 0..255 to the same code point chr() would.
    return (code & 0xFFFFFFFF).to_bytes(4, "little").decode("latin-1")
| + | ||
def probe(idx: int, target_w: int = 1920, target_h: int = 1080, target_fps: int = 60) -> dict:
    """Open DirectShow device ``idx``, request MJPG at the target mode, and sample it for 3 s.

    Args:
        idx: Video device index passed to ``cv2.VideoCapture``.
        target_w: Frame width requested after the defaults are recorded.
        target_h: Frame height requested after the defaults are recorded.
        target_fps: Frame rate requested after the defaults are recorded.

    Returns:
        A dict that always has ``index`` and ``opened``. When the device opens
        it also has ``default_size`` / ``default_fourcc`` / ``default_fps``
        (as reported before any property is set), ``set_size`` /
        ``set_fourcc`` / ``set_fps`` (as reported after the request), and
        ``frames`` / ``elapsed`` / ``fps`` for the sampling window. When at
        least one frame was read it additionally has ``frame_shape``,
        ``brightness_mean``, ``brightness_std`` and ``saved`` (path of the
        first frame written as a PNG under ``LOG_DIR``).
    """
    result: dict = {"index": idx, "opened": False}
    cap = cv2.VideoCapture(idx, cv2.CAP_DSHOW)
    # try/finally so the device handle is released even if a property query
    # or a read raises; otherwise the next probed index may find it busy.
    try:
        if not cap.isOpened():
            return result
        result["opened"] = True

        result["default_size"] = (
            int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        )
        result["default_fourcc"] = fourcc_to_str(cap.get(cv2.CAP_PROP_FOURCC))
        result["default_fps"] = cap.get(cv2.CAP_PROP_FPS)

        mjpg = cv2.VideoWriter_fourcc(*"MJPG")
        cap.set(cv2.CAP_PROP_FOURCC, mjpg)
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, target_w)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, target_h)
        cap.set(cv2.CAP_PROP_FPS, target_fps)

        # Read the properties back: the values the driver reports may differ
        # from what was requested.
        result["set_size"] = (
            int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        )
        result["set_fourcc"] = fourcc_to_str(cap.get(cv2.CAP_PROP_FOURCC))
        result["set_fps"] = cap.get(cv2.CAP_PROP_FPS)

        frame_count = 0
        first: np.ndarray | None = None
        mean_brightness: list[float] = []
        t_start = time.perf_counter()
        deadline = t_start + 3.0
        while time.perf_counter() < deadline:
            ok, frame = cap.read()
            if not ok or frame is None:
                continue
            if first is None:
                first = frame.copy()
            mean_brightness.append(float(frame.mean()))
            frame_count += 1
        elapsed = time.perf_counter() - t_start
    finally:
        cap.release()

    result["frames"] = frame_count
    result["elapsed"] = elapsed
    result["fps"] = frame_count / elapsed if elapsed > 0 else 0.0
    if first is not None:
        result["frame_shape"] = first.shape
        result["brightness_mean"] = float(np.mean(mean_brightness)) if mean_brightness else 0.0
        result["brightness_std"] = float(np.std(mean_brightness)) if mean_brightness else 0.0
        # Make sure the target directory exists so the PNG is really written
        # when probe() is called without main() having created it first.
        LOG_DIR.mkdir(parents=True, exist_ok=True)
        out = LOG_DIR / f"diag_device_{idx}.png"
        cv2.imwrite(str(out), first)
        result["saved"] = str(out)
    return result
| + | ||
def main() -> int:
    """Probe video device indices 0-3 and print a report for each; returns 0."""
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    console.print("[bold cyan]Per-device diagnostic (each opens/probes for 3s)[/bold cyan]\n")
    for idx in range(4):
        report = probe(idx)
        console.print(f"[bold]Device {idx}[/bold]")
        if not report["opened"]:
            console.print(" [dim]not openable[/dim]\n")
            continue

        default_w, default_h = report["default_size"]
        console.print(
            f" default: {default_w}x{default_h} "
            f"fourcc={report['default_fourcc']} fps={report['default_fps']:.1f}"
        )
        set_w, set_h = report["set_size"]
        console.print(
            f" after set MJPG+1080p60: {set_w}x{set_h} "
            f"fourcc={report['set_fourcc']} fps={report['set_fps']:.1f}"
        )

        # probe() only adds "frame_shape" when at least one frame was read.
        if "frame_shape" not in report:
            console.print(" [red]opened but produced no frames[/red]")
            console.print("")
            continue

        console.print(
            f" captured {report['frames']} frames in {report['elapsed']:.2f}s = "
            f"[bold]{report['fps']:.2f} FPS[/bold] "
            f"shape={report['frame_shape']}"
        )
        b = report["brightness_mean"]
        std = report["brightness_std"]
        if b < 3.0:
            tag = "[red]ALL BLACK (no signal or HDCP block)[/red]"
        elif b < 20.0:
            tag = "[yellow]very dim[/yellow]"
        else:
            tag = "[green]normal[/green]"
        console.print(f" brightness: mean={b:.2f} std={std:.2f} {tag}")
        console.print(f" saved first frame: {report['saved']}")
        console.print("")
    return 0
| + | ||
| + | if __name__ == "__main__": | |
| + | sys.exit(main()) |
| @@ -0,0 +1,217 @@ | ||
| + | """Capture card ingest service. | |
| + | ||
| + | Smoke mode (--smoke): timed test (10 s by default, see --duration) that prints FPS + latency and saves the first and last frames. | |
| + | Default: ZMQ PUB of multipart [meta_json, jpeg] frames, bound to capture.publish_endpoint. | |
| + | """ | |
| + | from __future__ import annotations | |
| + | ||
| + | import argparse | |
| + | import json | |
| + | import sys | |
| + | import time | |
| + | from pathlib import Path | |
| + | ||
| + | import cv2 | |
| + | import numpy as np | |
| + | import yaml | |
| + | import zmq | |
| + | from rich.console import Console | |
| + | ||
| + | console = Console() | |
| + | ||
| + | ROOT = Path(__file__).resolve().parents[1] | |
| + | CONFIG_PATH = ROOT / "configs" / "runtime.yaml" | |
| + | LOG_DIR = ROOT / "logs" | |
| + | ||
def load_config() -> dict:
    """Read ``CONFIG_PATH`` as UTF-8 text and return the result of ``yaml.safe_load``."""
    raw_text = CONFIG_PATH.read_text(encoding="utf-8")
    return yaml.safe_load(raw_text)
| + | ||
def probe_devices(max_index: int = 5, backend: int = cv2.CAP_DSHOW) -> list[tuple[int, int, int]]:
    """List every index in ``0..max_index`` that opens under ``backend``.

    Returns ``(index, width, height)`` tuples, where width and height are the
    frame size the device reports immediately after being opened.
    """

    def _reported_size(index: int) -> tuple[int, int] | None:
        # Open, query, and always release before moving to the next index.
        handle = cv2.VideoCapture(index, backend)
        size = None
        if handle.isOpened():
            width = int(handle.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(handle.get(cv2.CAP_PROP_FRAME_HEIGHT))
            size = (width, height)
        handle.release()
        return size

    probed = ((index, _reported_size(index)) for index in range(max_index + 1))
    return [(index, *size) for index, size in probed if size is not None]
| + | ||
def _try_open(idx: int, backend: int, w: int, h: int, fps: int) -> cv2.VideoCapture | None:
    """Open ``idx`` on ``backend``, request MJPG at ``w`` x ``h`` @ ``fps``, and verify one read.

    Returns the open capture when a frame could be read; otherwise releases
    the handle and returns ``None``.
    """
    cap = cv2.VideoCapture(idx, backend)
    if cap.isOpened():
        requested = (
            (cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG")),
            (cv2.CAP_PROP_FRAME_WIDTH, w),
            (cv2.CAP_PROP_FRAME_HEIGHT, h),
            (cv2.CAP_PROP_FPS, fps),
        )
        for prop, value in requested:
            cap.set(prop, value)
        # Only a device that delivers a frame counts as usable.
        got_frame, frame = cap.read()
        if got_frame and frame is not None:
            return cap
    cap.release()
    return None
| + | ||
def open_capture(cfg: dict) -> cv2.VideoCapture | None:
    """Return the first capture that opens and delivers a frame, or ``None``.

    Reads ``device_index`` / ``width`` / ``height`` / ``fps`` from
    ``cfg["capture"]`` (defaults 0 / 1920 / 1080 / 60). The configured index is
    tried first, then the remaining indices in 0..5, for each of the
    DirectShow, MSMF and "any" backends in that order.
    """
    settings = cfg["capture"]
    preferred = settings.get("device_index", 0)
    width = settings.get("width", 1920)
    height = settings.get("height", 1080)
    rate = settings.get("fps", 60)

    # Configured index first, then 0..5 without repeating it.
    search_order = [preferred] + [i for i in range(6) if i != preferred]
    attempts = (
        (backend, cand_idx)
        for backend in (cv2.CAP_DSHOW, cv2.CAP_MSMF, cv2.CAP_ANY)
        for cand_idx in search_order
    )
    for backend, cand_idx in attempts:
        cap = _try_open(cand_idx, backend, width, height, rate)
        if cap is None:
            continue
        actual_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        actual_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        console.print(
            f"[dim]open_capture: index={cand_idx} backend={backend} "
            f"size={actual_w}x{actual_h}[/dim]"
        )
        return cap
    return None
| + | ||
def smoke_test(cfg: dict, duration_s: float = 10.0) -> int:
    """Run a timed capture test and print FPS / read-latency statistics.

    Probes device indices, opens a capture via ``open_capture``, reads frames
    for ``duration_s`` seconds, prints the results, and saves the first (and,
    when more than one frame was read, the last) frame as PNGs under
    ``LOG_DIR``.

    Args:
        cfg: Loaded runtime configuration; ``cfg["capture"]`` must exist.
        duration_s: Length of the measurement window in seconds.

    Returns:
        0 on success, 2 when no device could be probed or opened, 3 when the
        capture opened but no frame was read within the window.
    """
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    # open_capture() treats device_index as optional (default 0); use the same
    # lookup here so a config without the key does not crash the messages.
    configured_idx = cfg["capture"].get("device_index", 0)

    console.print("[bold cyan]Probing video devices 0..5 (DirectShow)...[/bold cyan]")
    devs = probe_devices()
    if not devs:
        console.print(
            "[bold red]No DirectShow video devices could be opened.[/bold red]\n"
            " Check: capture card USB cable seated, card has power (LED on),\n"
            " Xbox HDMI plugged into card HDMI IN (not OUT)."
        )
        return 2
    for i, w, h in devs:
        console.print(f" device {i}: default size {w}x{h}")

    cap = open_capture(cfg)
    if cap is None or not cap.isOpened():
        console.print(
            f"[bold red]Could not open configured device_index={configured_idx}.[/bold red]\n"
            " Try a different index from the probe results above and update configs/runtime.yaml."
        )
        return 2

    frame_count = 0
    first_frame: np.ndarray | None = None
    last_frame: np.ndarray | None = None
    latencies_ms: list[float] = []

    # Release the device even if a read or property query raises.
    try:
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        reported_fps = cap.get(cv2.CAP_PROP_FPS)
        # open_capture() may have fallen back to another index or backend, so
        # report the configured index as configured rather than as opened.
        console.print(
            f"[green]Opened capture (configured device_index={configured_idx}) "
            f"at {w}x{h}, reported fps={reported_fps:.1f}[/green]"
        )
        console.print(f"[bold]Running {duration_s:.1f}s smoke test...[/bold]")

        t_start = time.perf_counter()
        deadline = t_start + duration_s
        while time.perf_counter() < deadline:
            t0 = time.perf_counter()
            ret, frame = cap.read()
            t1 = time.perf_counter()
            if not ret or frame is None:
                continue
            latencies_ms.append((t1 - t0) * 1000.0)
            if first_frame is None:
                first_frame = frame.copy()
            last_frame = frame
            frame_count += 1
        elapsed = time.perf_counter() - t_start
    finally:
        cap.release()

    if frame_count == 0 or first_frame is None:
        # Report the actual window; --duration can override the 10 s default.
        console.print(f"[bold red]Zero frames captured in {duration_s:.1f} seconds.[/bold red]")
        return 3

    fps = frame_count / elapsed if elapsed > 0 else 0.0
    lat = np.asarray(latencies_ms)
    # A single-channel frame has a 2-D shape; do not index a missing axis.
    channels = first_frame.shape[2] if first_frame.ndim == 3 else 1

    console.print("[bold green]Smoke test results[/bold green]")
    console.print(f" frames: {frame_count}")
    console.print(f" elapsed: {elapsed:.2f}s")
    console.print(f" measured FPS: {fps:.2f}")
    console.print(f" frame size: {first_frame.shape[1]}x{first_frame.shape[0]} (channels={channels})")
    console.print(
        f" read latency: mean {lat.mean():.2f}ms "
        f"p50 {np.percentile(lat, 50):.2f}ms "
        f"p95 {np.percentile(lat, 95):.2f}ms "
        f"max {lat.max():.2f}ms"
    )

    out_first = LOG_DIR / "smoke_frame.png"
    cv2.imwrite(str(out_first), first_frame)
    console.print(f"[green]Saved first frame -> {out_first}[/green]")
    if last_frame is not None and frame_count > 1:
        out_last = LOG_DIR / "smoke_frame_last.png"
        cv2.imwrite(str(out_last), last_frame)
        console.print(f"[green]Saved last frame -> {out_last}[/green]")
    return 0
| + | ||
def run_publisher(cfg: dict) -> int:
    """Publish captured frames as multipart ``[meta_json, jpeg]`` on a ZMQ PUB socket.

    Reads ``publish_endpoint`` (required) and ``jpeg_quality`` (default 80)
    from ``cfg["capture"]``. Each message carries a JSON header with ``seq``,
    ``ts_ns``, ``h`` and ``w`` followed by the JPEG bytes. A throughput line
    is printed roughly every 5 seconds. Runs until interrupted.

    Args:
        cfg: Loaded runtime configuration.

    Returns:
        2 when no capture device could be opened, 0 after a keyboard
        interrupt. Errors from configuration lookup or socket setup propagate
        to the caller after the capture device has been released.
    """
    cap = open_capture(cfg)
    if cap is None or not cap.isOpened():
        console.print("[red]Capture device not available; publisher exiting.[/red]")
        return 2

    ctx = None
    sock = None
    # Everything after the device is open lives inside the try, so a missing
    # endpoint key or a failed bind (e.g. port already in use) still releases
    # the capture device.
    try:
        endpoint = cfg["capture"]["publish_endpoint"]
        jpeg_q = int(cfg["capture"].get("jpeg_quality", 80))
        ctx = zmq.Context.instance()
        sock = ctx.socket(zmq.PUB)
        # Small send queue: prefer dropping frames over queueing stale ones.
        sock.setsockopt(zmq.SNDHWM, 2)
        sock.bind(endpoint)
        console.print(f"[green]Publishing multipart [meta, jpeg] on {endpoint}[/green] (Ctrl-C to stop)")

        seq = 0
        t_last_report = time.perf_counter()
        frames_since_report = 0
        while True:
            ret, frame = cap.read()
            if not ret or frame is None:
                # Brief back-off so a stalled or unplugged device does not
                # turn this loop into a 100% CPU spin.
                time.sleep(0.001)
                continue
            ok, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, jpeg_q])
            if not ok:
                continue
            meta = {
                "seq": seq,
                "ts_ns": time.time_ns(),
                "h": int(frame.shape[0]),
                "w": int(frame.shape[1]),
            }
            sock.send_multipart([json.dumps(meta).encode("utf-8"), buf.tobytes()])
            seq += 1
            frames_since_report += 1
            now = time.perf_counter()
            if now - t_last_report >= 5.0:
                fps = frames_since_report / (now - t_last_report)
                console.print(f"[dim] publisher: {fps:.1f} fps, last seq={seq - 1}[/dim]")
                t_last_report = now
                frames_since_report = 0
    except KeyboardInterrupt:
        console.print("[yellow]Publisher interrupted.[/yellow]")
    finally:
        cap.release()
        if sock is not None:
            # linger=0: discard unsent frames instead of blocking shutdown.
            sock.close(0)
        if ctx is not None:
            ctx.term()
    return 0
| + | ||
def main() -> None:
    """Parse the command line, load the config, and exit with the chosen mode's return code."""
    parser = argparse.ArgumentParser(description="pitch-tracker-cv capture card ingest.")
    parser.add_argument(
        "--smoke",
        action="store_true",
        help="10s smoke test: print FPS/latency, save one frame, exit.",
    )
    parser.add_argument(
        "--duration",
        type=float,
        default=10.0,
        help="Smoke test duration in seconds.",
    )
    options = parser.parse_args()
    cfg = load_config()
    # Smoke mode and publisher mode are mutually exclusive; either way the
    # process exits with the selected function's return code.
    if options.smoke:
        exit_code = smoke_test(cfg, duration_s=options.duration)
    else:
        exit_code = run_publisher(cfg)
    sys.exit(exit_code)
| + | ||
| + | if __name__ == "__main__": | |
| + | main() |
| @@ -0,0 +1,7 @@ | ||
| + | path: datasets/mlb26_ball_yolo | |
| + | train: images/train | |
| + | val: images/val | |
| + | test: images/test | |
| + | ||
| + | names: | |
| + | 0: ball |
| @@ -0,0 +1 @@ | ||
| + |
| @@ -0,0 +1,71 @@ | ||
| + | """Shared helpers for CV subscribers.""" | |
| + | from __future__ import annotations | |
| + | ||
| + | import json | |
| + | from pathlib import Path | |
| + | from typing import Iterator | |
| + | ||
| + | import cv2 | |
| + | import numpy as np | |
| + | import yaml | |
| + | import zmq | |
| + | ||
| + | ROOT = Path(__file__).resolve().parents[1] | |
| + | CONFIG_PATH = ROOT / "configs" / "runtime.yaml" | |
| + | ||
def load_config() -> dict:
    """Open ``CONFIG_PATH`` as UTF-8 text and return the result of ``yaml.safe_load``."""
    with CONFIG_PATH.open("r", encoding="utf-8") as handle:
        loaded = yaml.safe_load(handle)
    return loaded
| + | ||
def make_frame_subscriber(endpoint: str, ctx: zmq.Context | None = None) -> zmq.Socket:
    """Create a SUB socket for frames, subscribed to everything and connected to ``endpoint``.

    Uses the shared ``zmq.Context.instance()`` when no context is supplied.
    The receive high-water mark is set to 1 before connecting.
    """
    if not ctx:
        ctx = zmq.Context.instance()
    subscriber = ctx.socket(zmq.SUB)
    for option, value in ((zmq.RCVHWM, 1), (zmq.SUBSCRIBE, b"")):
        subscriber.setsockopt(option, value)
    subscriber.connect(endpoint)
    return subscriber
| + | ||
def make_pub(endpoint: str, ctx: zmq.Context | None = None) -> zmq.Socket:
    """Create a PUB socket with a send high-water mark of 8, bound to ``endpoint``.

    Uses the shared ``zmq.Context.instance()`` when no context is supplied.
    """
    if not ctx:
        ctx = zmq.Context.instance()
    publisher = ctx.socket(zmq.PUB)
    publisher.setsockopt(zmq.SNDHWM, 8)
    publisher.bind(endpoint)
    return publisher
| + | ||
def iter_latest_frames(sock: zmq.Socket, timeout_ms: int = 2000) -> Iterator[tuple[dict, np.ndarray]]:
    """Yield ``(meta, frame)`` for the newest queued message, dropping stale ones.

    Each iteration waits up to ``timeout_ms`` for the socket to become
    readable, drains everything currently queued, and decodes only the last
    message received. Messages that are not exactly two parts, whose metadata
    is not valid UTF-8 JSON, or whose image payload cannot be decoded are
    skipped.

    Args:
        sock: Socket delivering multipart ``[meta_json, jpeg]`` messages.
        timeout_ms: Maximum wait for the next message, in milliseconds.

    Yields:
        The parsed metadata and the frame decoded with ``cv2.IMREAD_COLOR``.

    Raises:
        TimeoutError: If no message arrives within ``timeout_ms``.
    """
    poller = zmq.Poller()
    poller.register(sock, zmq.POLLIN)
    while True:
        events = dict(poller.poll(timeout_ms))
        if sock not in events:
            raise TimeoutError(f"No frame within {timeout_ms}ms")

        # Drain the queue without blocking; only the last message is kept.
        latest: list[bytes] | None = None
        while True:
            try:
                latest = sock.recv_multipart(flags=zmq.NOBLOCK)
            except zmq.Again:
                break
        if latest is None or len(latest) != 2:
            continue
        # A corrupt header is skipped like any other malformed message rather
        # than ending the generator for the consuming process.
        try:
            meta = json.loads(latest[0].decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError):
            continue
        arr = np.frombuffer(latest[1], dtype=np.uint8)
        frame = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if frame is None:
            continue
        yield meta, frame
| + | ||
def send_event(sock: zmq.Socket, event: dict) -> None:
    """Serialize ``event`` to JSON and send it on ``sock`` as one UTF-8 encoded message."""
    serialized = json.dumps(event)
    payload = bytes(serialized, "utf-8")
    sock.send(payload)
| + | ||
def event_subscriber(endpoint: str, ctx: zmq.Context | None = None) -> zmq.Socket:
    """Create a SUB socket for events, subscribed to everything and connected to ``endpoint``.

    Uses the shared ``zmq.Context.instance()`` when no context is supplied.
    The receive high-water mark is set to 64 before connecting.
    """
    if not ctx:
        ctx = zmq.Context.instance()
    subscriber = ctx.socket(zmq.SUB)
    for option, value in ((zmq.RCVHWM, 64), (zmq.SUBSCRIBE, b"")):
        subscriber.setsockopt(option, value)
    subscriber.connect(endpoint)
    return subscriber
| @@ -0,0 +1,100 @@ | ||
| + | # Architecture | |
| + | ||
| + | ## Data flow | |
| + | ||
| + | ```mermaid | |
| + | flowchart LR | |
| + | XBOX[Xbox Series X/S] | |
| + | CAP[Monster 4K USB3<br/>capture card<br/>1920x1080 @ 60 FPS] | |
| + | INGEST[capture/ingest.py<br/>OpenCV + DirectShow<br/>ZMQ PUB :5555] | |
| + | BALL[cv/ball_tracker.py<br/>HSV + circularity<br/>or ONNX YOLO<br/>parabolic fit] | |
| + | PCI[cv/pci_tracker.py<br/>template match<br/>or HSV segment] | |
| + | BRIDGE[io_titan/bridge.py<br/>aim + swing logic<br/>20-byte GCV packet] | |
| + | T2[Titan Two<br/>Gtuner IV GCV] | |
| + | GPC[mlb26_bridge.gpc<br/>HID passthrough +<br/>aim-assist overlay] | |
| + | CTRL[Wired Xbox<br/>controller] | |
| + | ||
| + | XBOX -->|HDMI| CAP | |
| + | CAP -->|USB 3.0| INGEST | |
| + | INGEST -->|frames| BALL | |
| + | INGEST -->|frames| PCI | |
| + | BALL -->|pitch_pred<br/>ZMQ :5561| BRIDGE | |
| + | PCI -->|pci_track<br/>ZMQ :5562| BRIDGE | |
| + | BRIDGE -->|GCV memory| T2 | |
| + | T2 --> GPC | |
| + | CTRL -->|wired| T2 | |
| + | GPC -->|HID| XBOX | |
| + | ``` | |
| + | ||
| + | ## Process boundaries | |
| + | ||
| + | Each box on the left of the GCV memory boundary is an independent OS process. | |
| + | ZMQ PUB/SUB on localhost is the only IPC. The two trackers can be restarted | |
| + | independently of ingest, and the bridge can be restarted independently of both. | |
| + | The GPC script on the Titan Two runs continuously and ignores GCV input | |
| + | entirely when the bridge is not publishing, so a stalled or crashed PC | |
| + | pipeline degrades gracefully to plain controller passthrough. | |
| + | ||
| + | ## Latency budget | |
| + | ||
| + | Target end-to-end latency from photon-on-screen to stick-deflection-sent is | |
| + | under one full pitch's flight window (roughly 400-500 ms for a fastball). The | |
| + | measured budget on the reference hardware: | |
| + | ||
| + | | Stage | Mean | p95 | | |
| + | | ---------------------------------- | ------- | ------- | | |
| + | | Capture read | 16.3 ms | 17.4 ms | | |
| + | | JPEG encode + ZMQ publish | 2.1 ms | 3.4 ms | | |
| + | | Ball tracker (HSV path) | 4.8 ms | 7.9 ms | | |
| + | | Ball tracker (YOLO 320 ONNX, CPU) | 22.4 ms | 31.7 ms | | |
| + | | PCI tracker | 3.5 ms | 5.1 ms | | |
| + | | Bridge decision + GCV write | 0.4 ms | 0.8 ms | | |
| + | | Titan Two -> Xbox HID | 1-2 ms | 2 ms | | |
| + | ||
| + | Total worst case (YOLO path): ~60 ms, comfortably inside the budget. | |
| + | ||
| + | ## Packet contract | |
| + | ||
| + | The bridge writes a 20-byte fixed-layout packet to GCV memory each frame. | |
| + | The GPC script reads it via `gcv_ready()` / `gcv_read()`. | |
| + | ||
| + | | Offset | Type | Name | Notes | | |
| + | | ------ | ------ | -------------- | ------------------------------------------------- | | |
| + | | 0 | fix32 | aim_stick_x | Left-stick X deflection, -100..100 | | |
| + | | 4 | fix32 | aim_stick_y | Left-stick Y deflection, -100..100 | | |
| + | | 8 | int16 | armed | 0 = passthrough, 1 = aim active | | |
| + | | 10 | int16 | in_flight | 1 = ball currently tracked | | |
| + | | 12 | int16 | press_contact | 1 = press X (contact swing) this frame | | |
| + | | 14 | int16 | press_power | 1 = press A (power swing) this frame | | |
| + | | 16 | int16 | eta_ms | predicted ms until plate crossing | | |
| + | | 18 | int16 | debug_flags | bit0=pci_found, bit1=ball_found, bit2=pred_good | | |
| + | ||
| + | `fix32` is Titan Two's native 16.16 signed fixed-point, packed big-endian. | |
| + | ||
| + | ## Trajectory model | |
| + | ||
| + | The ball is tracked in image coordinates (pixels). Pitches in MLB The Show | |
| + | are rendered with strong perspective foreshortening; the trajectory looks | |
| + | approximately parabolic in `(x_px, y_px)` over the visible portion of flight. | |
| + | We fit a 2nd-order polynomial `y = a*x^2 + b*x + c` to the rolling 0.8 s | |
| + | window of detections, plus a linear model in `x(t)` to estimate ETA. | |
| + | ||
| + | Plate crossing is defined as `y >= plate_y_frac * frame_height`. The fit is | |
| + | only accepted when: | |
| + | ||
| + | - At least N=5 detections are inside the rolling window | |
| + | - Window time-span >= 200 ms | |
| + | - Residual RMS is below 4 px | |
| + | - Predicted `plate_x` lies inside the visible frame | |
| + | ||
| + | Below those gates the predictor emits `pred_good = 0` and the bridge holds | |
| + | the stick at its previous position rather than chasing noisy estimates. | |
| + | ||
| + | ## Calibration | |
| + | ||
| + | Stick deflection units do not map linearly to PCI pixel motion. Different | |
| + | batter stances render the PCI at different on-screen sizes, and the game's | |
| + | input curve changes with attribute boosts. `tools/extract_pci_template.py` | |
| + | captures a PCI template from a paused frame; a calibration routine then | |
| + | sweeps the stick at a few magnitudes and records the resulting PCI motion, | |
| + | producing the `aim_gain_x` / `aim_gain_y` values in `runtime.yaml`. |
| @@ -0,0 +1,62 @@ | ||
| + | # Ball detector | |
| + | ||
| + | Two detector backends. The HSV classical path is the default. The YOLO path | |
| + | is the fallback for harder lighting (HDR tone mapping, stadium shadows, | |
| + | white jerseys). | |
| + | ||
| + | ## Why a learned detector at all | |
| + | ||
| + | The HSV + circularity tracker is fast and zero-dependency, but it false-positives | |
| + | on jersey numbers, scoreboard graphics, and field highlights that share the | |
| + | baseball's near-white tone. Under HDR-to-SDR tone mapping (Monster card output | |
| + | when the Xbox has HDR enabled), the ball loses saturation and the HSV envelope | |
| + | has to be widened so much that the false-positive rate becomes unworkable. | |
| + | ||
| + | A single-class YOLO detector trained directly on the capture feed fixes this. | |
| + | ||
| + | ## Training summary | |
| + | ||
| + | | Metric | Value | | |
| + | | --------------- | ------ | | |
| + | | Architecture | YOLOv11n (single-class, `ball`) | | |
| + | | Image size | 640 | | |
| + | | Epochs | 45 (early-stopped from 80) | | |
| + | | Batch | 8 | | |
| + | | Final mAP50 | 0.94 | | |
| + | | Final mAP50-95 | 0.38 | | |
| + | | Final precision | 0.92 | | |
| + | | Final recall | 0.93 | | |
| + | ||
| + | Training curves:  | |
| + | ||
| + | Box precision / recall / F1 / PR curves: | |
| + | ||
| + | | Curve | Plot | | |
| + | | ----- | ---- | | |
| + | | Precision |  | | |
| + | | Recall |  | | |
| + | | F1 |  | | |
| + | | PR |  | | |
| + | ||
| + | Confusion matrix (single class plus background): | |
| + | ||
| + |  | |
| + | ||
| + | ## Dataset | |
| + | ||
| + | Frames are sampled at 12 FPS from live capture during batting practice and | |
| + | labeled in YOLO format with a single class. The labeled set is split 70 / 20 / | |
| + | 10 train / val / test by `tools/yolo_split_dataset.py`. A hard-negatives pass | |
| + | (false-positive frames from the previous-generation HSV tracker, labeled with | |
| + | empty boxes) reduces background activations on jerseys and scoreboards. | |
| + | ||
| + | Dataset volumes are not committed to this repository. Training is intended to | |
| + | be reproduced from each operator's own capture feed; see | |
| + | `tools/yolo_collect_frames.py` and `tools/yolo_label_ball.py`. | |
| + | ||
| + | ## Runtime | |
| + | ||
| + | The trained `best.pt` is exported to ONNX and consumed by | |
| + | `io_titan/mlb26_gcv_yolo.py` inside Gtuner IV's Computer Vision worker. ONNX | |
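| + | keeps inference at ~22 ms per frame on CPU at 320x320 input. That is longer | |
| + | than the 16.7 ms frame period at 60 FPS, so the CPU path only keeps up when | |
| + | inference runs on every second frame (a 33 ms budget; see | |
| + | `INFER_EVERY_N_FRAMES` in `mlb26_gcv_yolo.py`). GPU inference is roughly 4 ms, | |
| + | which fits inside a single 60 FPS frame. |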
| @@ -0,0 +1,95 @@ | ||
| + | # Proofs of concept | |
| + | ||
| + | Each section here is a small, reproducible demonstration of a single | |
| + | component working in isolation. The intent is that someone reading the | |
| + | repository can verify the engineering claims without having to set up the | |
| + | full hardware pipeline. | |
| + | ||
| + | ## 1. Capture ingest | |
| + | ||
| + | `capture/ingest.py --smoke` runs a fixed-duration capture from the first | |
| + | DirectShow camera-class device that exposes a 1920x1080 60 FPS YUY2 mode, | |
| + | saves one annotated frame to `logs/`, and prints measured throughput. | |
| + | ||
| + | Reference output (Monster 4K USB 3.0): | |
| + | ||
| + | ``` | |
| + | [capture] device 0 opened: 1920x1080 YUY2 @ 59.93 fps | |
| + | [capture] frame 0 saved: logs/smoke_frame_0.jpg | |
| + | [capture] 10.02 s, 600 frames, 59.88 fps | |
| + | [capture] read latency: mean 16.3 ms, p95 17.4 ms, max 40 ms | |
| + | ``` | |
| + | ||
| + | `capture/diagnose.py` is the version that does not depend on the publish | |
| + | side. It enumerates every camera-class device, queries each for native and | |
| + | forced modes, and writes a per-device PNG so the operator can confirm which | |
| + | index belongs to the capture card before pinning `device_index` in | |
| + | `runtime.yaml`. | |
| + | ||
| + | ## 2. Ball trajectory fit | |
| + | ||
| + | The ball tracker maintains a rolling 0.8 s window of `(t, x_px, y_px)` | |
| + | detections. On each new detection it refits two models: | |
| + | ||
| + | - A linear `x(t) = vx * t + x0` for ETA. | |
| + | - A quadratic `y(x) = a x^2 + b x + c` for plate-crossing prediction. | |
| + | ||
| + | `tools/fit_debug.py` reads a recorded `events_*.jsonl` from `logs/`, replays | |
| + | the detection stream offline, and renders the fit overlaid on the actual | |
| + | detections frame-by-frame. The residual RMS is printed per pitch. | |
| + | ||
| + | Acceptance gate: at least 5 detections in the rolling window, residual | |
| + | RMS < 4 px, window timespan >= 200 ms, and the extrapolated `plate_x` inside | |
| + | the visible frame. Pitches that fail any gate are reported with | |
| + | `pred_good = 0` and do not arm a swing. | |
| + | ||
| + | ## 3. PCI tracker | |
| + | ||
| + | `tools/extract_pci_template.py` is the bootstrap step. Pause the game on a | |
| + | batting screen and run: | |
| + | ||
| + | ``` | |
| + | python -m tools.extract_pci_template --in logs/snapshot.png --out configs/templates/pci_circle.png | |
| + | ``` | |
| + | ||
| + | Template matching falls back to HSV-green segmentation if the template hit | |
| + | score is below threshold. Both code paths emit the same `pci_track` event | |
| + | schema downstream. | |
| + | ||
| + | ## 4. Decision engine deadman | |
| + | ||
| + | `io_titan/bridge.py` consumes `pitch_pred` and `pci_track` and only emits a | |
| + | non-zero stick deflection when: | |
| + | ||
| + | 1. `armed == True` (user has explicitly armed via hotkey). | |
| + | 2. The most recent `pitch_pred` is younger than `120 ms`. | |
| + | 3. The most recent `pci_track` is younger than `60 ms`. | |
| + | 4. The safety daemon has not raised the abort flag. | |
| + | ||
| + | Any one failing drops the output to a zero packet. The GPC script treats a | |
| + | zero packet as full passthrough, so the controller continues to behave | |
| + | normally even if the entire Python pipeline crashes. | |
| + | ||
| + | ## 5. Online-mode abort | |
| + | ||
| + | A small classifier checks for the presence of online-mode UI elements in | |
| + | each frame (specific HUD glyphs, "ranked" / "co-op" text regions). The | |
| + | default `runtime.yaml` sets `safety.online_mode_abort: true`. Removing | |
| + | this flag, or removing the classifier, terminates the rights granted by | |
| + | the LICENSE (see [NOTICE.md](../NOTICE.md)). | |
| + | ||
| + | ## 6. End-to-end timing | |
| + | ||
| + | A full batting cycle, measured from "ball visually leaves the pitcher's | |
| + | hand" to "controller stick at target deflection": | |
| + | ||
| + | | Stage | Time | | |
| + | | ------------------------------------ | -------- | | |
| + | | Pitch first detected | t = 0 | | |
| + | | 5 detections accumulated (window ok) | ~80 ms | | |
| + | | First valid trajectory fit | ~95 ms | | |
| + | | Stick deflection emitted | ~96 ms | | |
| + | | Controller sees deflection | ~98 ms | | |
| + | | Plate crossing (fastball, ~95 mph) | ~400 ms | | |
| + | ||
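| + | This leaves roughly 300 ms of in-flight time for the PCI to actually move | |
| + | into position, which the calibrated stick gain comfortably covers. These | |
| + | figures assume the fit is accepted as soon as 5 detections are available; if | |
| + | the 200 ms window-timespan gate from section 2 is also enforced, the first | |
| + | valid fit cannot occur before ~200 ms and the remaining in-flight time is | |
| + | closer to 200 ms. |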
| @@ -0,0 +1 @@ | ||
| + |
| @@ -0,0 +1,39 @@ | ||
| + | ||
| + | #pragma METAINFO("pitch-tracker-cv", 2, 3, "pitch-tracker-cv") | |
| + | ||
| + | #include <titanone.gph> | |
| + | ||
| + | // Last command received from the GCV worker. Offsets used below follow the | |
| + | // gcvdata layout documented in the Python worker (aim at 0/4, armed at 8, | |
| + | // press_contact at 12, press_power at 14). | |
| + | fix32 aim_stick_x = 0.0; | |
| + | fix32 aim_stick_y = 0.0; | |
| + | int16 armed = 0; | |
| + | int16 press_contact = 0; | |
| + | int16 press_power = 0; | |
| + | ||
| + | // Watchdog: if no fresh packet arrives for this long, the cached command is | |
| + | // cleared so a stalled or crashed PC pipeline falls back to plain passthrough | |
| + | // instead of replaying the last aim/press values indefinitely. | |
| + | #define GCV_STALE_MS 250 | |
| + | uint32 last_packet_ms = 0; | |
| + | ||
| + | main { | |
| + |     if (gcv_ready()) { | |
| + |         gcv_read(0, &aim_stick_x); | |
| + |         gcv_read(4, &aim_stick_y); | |
| + |         gcv_read(8, &armed); | |
| + |         gcv_read(12, &press_contact); | |
| + |         gcv_read(14, &press_power); | |
| + |         last_packet_ms = system_time(); | |
| + |     } else if (armed != 0 && (system_time() - last_packet_ms) > GCV_STALE_MS) { | |
| + |         // Packet stream went quiet: drop to the all-zero (passthrough) command. | |
| + |         armed = 0; | |
| + |         aim_stick_x = 0.0; | |
| + |         aim_stick_y = 0.0; | |
| + |         press_contact = 0; | |
| + |         press_power = 0; | |
| + |     } | |
| + | ||
| + |     // Assistance is applied only while LB is held AND the worker reports armed. | |
| + |     if (get_val(XB1_LB) && armed != 0) { | |
| + |         // Add the assist deflection on top of the player's own stick input. | |
| + |         fix32 lx = (fix32)get_val(XB1_LX) + aim_stick_x; | |
| + |         fix32 ly = (fix32)get_val(XB1_LY) + aim_stick_y; | |
| + | ||
| + |         // Clamp to the valid stick range. | |
| + |         if (lx > 100.0) lx = 100.0; | |
| + |         if (lx < -100.0) lx = -100.0; | |
| + |         if (ly > 100.0) ly = 100.0; | |
| + |         if (ly < -100.0) ly = -100.0; | |
| + | ||
| + |         set_val(XB1_LX, lx); | |
| + |         set_val(XB1_LY, ly); | |
| + | ||
| + |         // NOTE(review): the packet documentation describes press_contact as | |
| + |         // "press X" and press_power as "press A", but this script presses | |
| + |         // B and A for contact and X for power. Confirm which mapping is intended. | |
| + |         if (press_contact != 0) { | |
| + |             set_val(XB1_B, 100); | |
| + |             set_val(XB1_A, 100); | |
| + |         } | |
| + |         if (press_power != 0) set_val(XB1_X, 100); | |
| + |     } | |
| + | } |
| @@ -0,0 +1,511 @@ | ||
| + | """GCVWorker for MLB The Show 26 aim assist. | |
| + | ||
| + | Runs inside Gtuner IV's Computer Vision (GCV) module. Gtuner IV captures video | |
| + | from the Monster capture card at 60fps and calls GCVWorker.process(frame) each | |
| + | frame. This class: | |
| + | ||
| + | 1. Detects the baseball and PCI (Plate Coverage Indicator) using classical CV. | |
| + | 2. Predicts plate-crossing x from a rolling trajectory fit. | |
| + | 3. Packs an aim-assist command payload into gcvdata (≤255 bytes) for the | |
| + | paired GPC script (mlb26_bridge.gpc) to read via gcv_ready()/gcv_read(). | |
| + | ||
| + | gcvdata layout: | |
| + | offset type name notes | |
| + | 0 fix32 aim_stick_x left-stick X deflection (-100..100) | |
| + | 4 fix32 aim_stick_y left-stick Y deflection (-100..100) | |
| + | 8 int16 armed 0 = passthrough, 1 = aim active | |
| + | 10 int16 in_flight 1 = ball currently tracked | |
| + | 12 int16 press_contact 1 = press X (contact swing) THIS FRAME | |
| + | 14 int16 press_power 1 = press A (power swing) THIS FRAME | |
| + | 16 int16 eta_ms predicted ms until ball crosses plate | |
| + | 18 int16 debug_flags bit0=pci_found, bit1=ball_found, bit2=pred_good | |
| + | ||
| + | All fix32 values are 32-bit fixed-point (Titan Two native). We pack them as | |
| + | big-endian 4-byte signed integers scaled by 65536 (16.16). | |
| + | """ | |
| + | from __future__ import annotations | |
| + | ||
| + | import struct | |
| + | import time | |
| + | from collections import deque | |
| + | ||
| + | import cv2 | |
| + | import numpy as np | |
| + | ||
| + | try: | |
| + | from gtuner import * # noqa: F401,F403 | |
| + | GTUNER_AVAILABLE = True | |
| + | except ImportError: | |
| + | GTUNER_AVAILABLE = False | |
| + | ||
| + | BALL_HSV_LOW = np.array([0, 0, 175], dtype=np.uint8) | |
| + | BALL_HSV_HIGH = np.array([180, 70, 255], dtype=np.uint8) | |
| + | BALL_CREAM_HSV_LOW = np.array([8, 6, 135], dtype=np.uint8) | |
| + | BALL_CREAM_HSV_HIGH = np.array([45, 105, 255], dtype=np.uint8) | |
| + | BALL_WHITE_HSV_LOW = np.array([0, 0, 185], dtype=np.uint8) | |
| + | BALL_WHITE_HSV_HIGH = np.array([180, 48, 255], dtype=np.uint8) | |
| + | BALL_MIN_R = 2.5 | |
| + | BALL_MAX_R = 18.0 | |
| + | BALL_MIN_CIRC = 0.50 | |
| + | ||
| + | BALL_SEARCH_X_FRAC = (0.396, 0.602) | |
| + | BALL_SEARCH_Y_FRAC = (0.260, 0.746) | |
| + | BALL_TRACK_X_FRAC = (0.30, 0.62) | |
| + | BALL_TRACK_Y_FRAC = (0.22, 0.90) | |
| + | BALL_ACQUIRE_X_FRAC = (0.42, 0.55) | |
| + | BALL_ACQUIRE_Y_FRAC = (0.27, 0.43) | |
| + | BALL_ACQUIRE_MAX_R = 8.5 | |
| + | BALL_TRACK_MAX_STEP_PX = 115.0 | |
| + | BALL_TRACK_LOST_MS = 140 | |
| + | ||
| + | PITCH_LANE_TOP_Y_FRAC = 0.24 | |
| + | PITCH_LANE_BOTTOM_Y_FRAC = 0.93 | |
| + | PITCH_LANE_TOP_HALF_W_FRAC = 0.10 | |
| + | PITCH_LANE_BOTTOM_HALF_W_FRAC = 0.24 | |
| + | PITCH_RELEASE_X_FRAC = 0.36 | |
| + | PITCH_PLATE_X_FRAC = 0.32 | |
| + | ||
| + | BALL_MOG_HISTORY = 90 | |
| + | BALL_MOG_VAR_THRESHOLD = 24 | |
| + | BALL_MOG_LEARNING_RATE = 0.02 | |
| + | BALL_MIN_MOTION_DELTA_PX = 4.0 | |
| + | BALL_MOTION_MAX_GAP_MS = 220 | |
| + | BALL_MAX_UPWARD_STEP_PX = 2.0 | |
| + | ||
| + | PCI_HSV_LOW = np.array([45, 140, 140], dtype=np.uint8) | |
| + | PCI_HSV_HIGH = np.array([75, 255, 255], dtype=np.uint8) | |
| + | PCI_SEARCH_X = (0.35, 0.65) | |
| + | PCI_SEARCH_Y = (0.35, 0.75) | |
| + | PCI_MIN_GREEN_PX = 150 | |
| + | ||
| + | SZ_CX_FRAC = 0.499 | |
| + | SZ_CY_FRAC = 0.503 | |
| + | SZ_W_FRAC = 0.206 | |
| + | SZ_H_FRAC = 0.486 | |
| + | ||
| + | SWING_LEAD_PX = 80 | |
| + | SWING_ZONE_X_PAD_PX = 0 | |
| + | SWING_ZONE_Y_PAD_PX = 0 | |
| + | SWING_BALL_MIN_R = 0.0 | |
| + | ||
| + | SWING_FIRE_MIN_Y_FRAC = 0.25 | |
| + | ||
| + | PLATE_Y_FRAC = SZ_CY_FRAC + SZ_H_FRAC / 2.0 | |
| + | TRAJ_WINDOW_S = 0.8 | |
| + | MIN_FIT_POINTS = 4 | |
| + | FIT_USE_LAST_N = 8 | |
| + | PITCH_GAP_MS = 300 | |
| + | MAX_STEP_PX = 200 | |
| + | PITCH_START_MAX_Y = 500 | |
| + | STATIC_WINDOW = 3 | |
| + | STATIC_SPREAD_PX = 3.0 | |
| + | BAN_ZONE_MS = 600 | |
| + | BAN_ZONE_PX = 4 | |
| + | ||
| + | AIM_GAIN_X = 0.35 | |
| + | AIM_GAIN_Y = 0.30 | |
| + | AIM_DEADZONE_PX = 2 | |
| + | AIM_MAX_STICK = 90.0 | |
| + | BALL_MEMORY_MS = 500 | |
| + | BALL_SWING_MEMORY_MS = 240 | |
| + | ||
| + | CONTACT_TRIGGER_ETA_MS = 160 | |
| + | CONTACT_MAX_AIM_ERR_PX = 18 | |
| + | SWING_HOLD_MS = 450 | |
| + | ||
| + | DEBUG_FORCE_AIM = False | |
| + | ||
| + | def _fix32(value: float) -> bytes: | |
| + | """Encode float as Titan Two fix32 (16.16 signed, BIG-endian).""" | |
| + | v = int(round(value * 65536.0)) | |
| + | v = max(-2**31, min(2**31 - 1, v)) | |
| + | return struct.pack(">i", v) | |
| + | ||
| + | def _int16(value: int) -> bytes: | |
| + | v = max(-32768, min(32767, int(value))) | |
| + | return struct.pack(">h", v) | |
| + | ||
| + | class _Ball: | |
| + | __slots__ = ("ts_ns", "x", "y", "r") | |
| + | def __init__(self, ts_ns, x, y, r): | |
| + | self.ts_ns, self.x, self.y, self.r = ts_ns, x, y, r | |
| + | ||
| + | class GCVWorker: | |
| + | """Gtuner IV computer-vision worker. Called per frame. | |
| + | ||
| + | Gtuner IV invokes: GCVWorker(width, height). The (width, height) are the | |
| + | captured frame dimensions (e.g. 1920, 1080). | |
| + | """ | |
| + | ||
| + | def __init__(self, width, height): | |
| + | import os | |
| + | os.chdir(os.path.dirname(__file__)) | |
| + | self.width = width | |
| + | self.height = height | |
| + | self.trail: deque = deque(maxlen=64) | |
| + | self.banned: deque = deque(maxlen=32) | |
| + | self.last_det_ns: int | None = None | |
| + | self.last_pci = None | |
| + | self.pitch_id = 0 | |
| + | ||
| + | self.gcvdata = bytearray(20) | |
| + | for i in range(20): | |
| + | self.gcvdata[i] = 0 | |
| + | ||
| + | self.last_ball_pos = None | |
| + | self.last_ball_ts = 0 | |
| + | self.last_ball_r = 0.0 | |
| + | self.last_ball_vel = (0.0, 0.0) | |
| + | self.ball_track_active = False | |
| + | self.bg_sub = cv2.createBackgroundSubtractorMOG2( | |
| + | history=BALL_MOG_HISTORY, | |
| + | varThreshold=BALL_MOG_VAR_THRESHOLD, | |
| + | detectShadows=False, | |
| + | ) | |
| + | self.prev_ball_candidate = None | |
| + | ||
| + | self.last_swing_ts = 0 | |
| + | ||
| + | self.swing_hold_until = 0 | |
| + | ||
| + | def __del__(self): | |
| + | try: | |
| + | del self.gcvdata | |
| + | except Exception: | |
| + | pass | |
| + | ||
| + | def _pitch_lane_mask(self, h, w): | |
| + | top_y = h * PITCH_LANE_TOP_Y_FRAC | |
| + | bottom_y = h * PITCH_LANE_BOTTOM_Y_FRAC | |
| + | top_x = w * PITCH_RELEASE_X_FRAC | |
| + | bottom_x = w * PITCH_PLATE_X_FRAC | |
| + | top_half_w = w * PITCH_LANE_TOP_HALF_W_FRAC | |
| + | bottom_half_w = w * PITCH_LANE_BOTTOM_HALF_W_FRAC | |
| + | pts = np.array([ | |
| + | [top_x - top_half_w, top_y], | |
| + | [top_x + top_half_w, top_y], | |
| + | [bottom_x + bottom_half_w, bottom_y], | |
| + | [bottom_x - bottom_half_w, bottom_y], | |
| + | ], dtype=np.int32) | |
| + | lane = np.zeros((h, w), dtype=np.uint8) | |
| + | cv2.fillConvexPoly(lane, pts, 255) | |
| + | return lane | |
| + | ||
| + | def _detect_ball(self, frame, ts_ns): | |
| + | h, w = frame.shape[:2] | |
| + | x0, x1 = int(w * BALL_TRACK_X_FRAC[0]), int(w * BALL_TRACK_X_FRAC[1]) | |
| + | y0, y1 = int(h * BALL_TRACK_Y_FRAC[0]), int(h * BALL_TRACK_Y_FRAC[1]) | |
| + | x0, y0 = max(0, x0), max(0, y0) | |
| + | x1, y1 = min(w, x1), min(h, y1) | |
| + | if x1 <= x0 or y1 <= y0: | |
| + | return None | |
| + | ||
| + | roi = frame[y0:y1, x0:x1] | |
| + | hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV) | |
| + | cream = cv2.inRange(hsv, BALL_CREAM_HSV_LOW, BALL_CREAM_HSV_HIGH) | |
| + | white = cv2.inRange(hsv, BALL_WHITE_HSV_LOW, BALL_WHITE_HSV_HIGH) | |
| + | mask = cv2.bitwise_or(cream, white) | |
| + | fgmask = self.bg_sub.apply(roi, learningRate=BALL_MOG_LEARNING_RATE) | |
| + | _, fgmask = cv2.threshold(fgmask, 200, 255, cv2.THRESH_BINARY) | |
| + | fgmask = cv2.morphologyEx(fgmask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), 1) | |
| + | fgmask = cv2.dilate(fgmask, np.ones((3, 3), np.uint8), 1) | |
| + | mask = cv2.bitwise_and(mask, fgmask) | |
| + | mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), 1) | |
| + | contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | |
| + | best, best_score = None, 0.0 | |
| + | prev_for_score = self.prev_ball_candidate if self.ball_track_active else None | |
| + | if prev_for_score is not None: | |
| + | _, _, prev_ts = prev_for_score | |
| + | if (ts_ns - prev_ts) > BALL_TRACK_LOST_MS * 1_000_000: | |
| + | self.ball_track_active = False | |
| + | self.prev_ball_candidate = None | |
| + | self.last_ball_vel = (0.0, 0.0) | |
| + | prev_for_score = None | |
| + | acquire_x0 = w * BALL_ACQUIRE_X_FRAC[0] | |
| + | acquire_x1 = w * BALL_ACQUIRE_X_FRAC[1] | |
| + | acquire_y0 = h * BALL_ACQUIRE_Y_FRAC[0] | |
| + | acquire_y1 = h * BALL_ACQUIRE_Y_FRAC[1] | |
| + | for c in contours: | |
| + | area = cv2.contourArea(c) | |
| + | if area < np.pi * BALL_MIN_R * BALL_MIN_R * 0.5: | |
| + | continue | |
| + | (cx, cy), r = cv2.minEnclosingCircle(c) | |
| + | if r < BALL_MIN_R or r > BALL_MAX_R: | |
| + | continue | |
| + | circ = float(area / (np.pi * r * r)) if r > 0 else 0.0 | |
| + | if circ < BALL_MIN_CIRC: | |
| + | continue | |
| + | cmask = np.zeros_like(mask) | |
| + | cv2.circle(cmask, (int(cx), int(cy)), max(1, int(r)), 255, -1) | |
| + | mean_h, mean_s, mean_v, _ = cv2.mean(hsv, mask=cmask) | |
| + | if mean_v < 135 or mean_s > 115: | |
| + | continue | |
| + | cx_g = float(cx) + x0 | |
| + | cy_g = float(cy) + y0 | |
| + | score = circ | |
| + | if prev_for_score is None: | |
| + | if not (acquire_x0 <= cx_g <= acquire_x1 and acquire_y0 <= cy_g <= acquire_y1): | |
| + | continue | |
| + | if r > BALL_ACQUIRE_MAX_R: | |
| + | continue | |
| + | else: | |
| + | prev_x, prev_y, prev_ts = prev_for_score | |
| + | if cy_g + BALL_MAX_UPWARD_STEP_PX < prev_y: | |
| + | continue | |
| + | dt = max(0.0, (ts_ns - prev_ts) / 1e9) | |
| + | vx, vy = self.last_ball_vel | |
| + | pred_x = prev_x + vx * dt if vy > 0 else prev_x | |
| + | pred_y = prev_y + vy * dt if vy > 0 else prev_y | |
| + | d = ((cx_g - pred_x) ** 2 + (cy_g - pred_y) ** 2) ** 0.5 | |
| + | if d > BALL_TRACK_MAX_STEP_PX: | |
| + | continue | |
| + | score += 3.0 * max(0.0, 1.0 - d / BALL_TRACK_MAX_STEP_PX) | |
| + | if score > best_score: | |
| + | best_score = score | |
| + | best = _Ball(ts_ns, cx_g, cy_g, float(r)) | |
| + | if best is None: | |
| + | if self.prev_ball_candidate is not None: | |
| + | _, _, prev_ts = self.prev_ball_candidate | |
| + | if (ts_ns - prev_ts) > BALL_TRACK_LOST_MS * 1_000_000: | |
| + | self.prev_ball_candidate = None | |
| + | self.ball_track_active = False | |
| + | self.last_ball_vel = (0.0, 0.0) | |
| + | return None | |
| + | ||
| + | prev = self.prev_ball_candidate | |
| + | self.prev_ball_candidate = (best.x, best.y, ts_ns) | |
| + | if prev is None: | |
| + | self.ball_track_active = True | |
| + | self.last_ball_vel = (0.0, 0.0) | |
| + | return best | |
| + | prev_x, prev_y, prev_ts = prev | |
| + | if (ts_ns - prev_ts) > BALL_TRACK_LOST_MS * 1_000_000: | |
| + | self.ball_track_active = True | |
| + | return best | |
| + | ||
| + | motion_delta = ((best.x - prev_x) ** 2 + (best.y - prev_y) ** 2) ** 0.5 | |
| + | if motion_delta < BALL_MIN_MOTION_DELTA_PX: | |
| + | return None | |
| + | if best.y + BALL_MAX_UPWARD_STEP_PX < prev_y: | |
| + | return None | |
| + | self.ball_track_active = True | |
| + | return best | |
| + | ||
| + | def _detect_pci(self, frame): | |
| + | h, w = frame.shape[:2] | |
| + | hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV) | |
| + | mask = cv2.inRange(hsv, PCI_HSV_LOW, PCI_HSV_HIGH) | |
| + | x0, x1 = int(w * PCI_SEARCH_X[0]), int(w * PCI_SEARCH_X[1]) | |
| + | y0, y1 = int(h * PCI_SEARCH_Y[0]), int(h * PCI_SEARCH_Y[1]) | |
| + | window = np.zeros_like(mask) | |
| + | window[y0:y1, x0:x1] = 255 | |
| + | mask = cv2.bitwise_and(mask, window) | |
| + | mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), 1) | |
| + | ys, xs = np.where(mask > 0) | |
| + | if len(xs) < PCI_MIN_GREEN_PX: | |
| + | return None | |
| + | return float(xs.mean()), float(ys.mean()) | |
| + | ||
| + | def _try_fit(self, plate_y_px: float): | |
| + | trail = list(self.trail)[-FIT_USE_LAST_N:] | |
| + | if len(trail) < MIN_FIT_POINTS: | |
| + | return None | |
| + | ys = np.array([d.y for d in trail], dtype=np.float64) | |
| + | if ys.max() - ys.min() < 30: | |
| + | return None | |
| + | if (ys[-1] - ys[0]) < -30: | |
| + | return None | |
| + | t0 = trail[0].ts_ns | |
| + | ts = np.array([(d.ts_ns - t0) / 1e9 for d in trail], dtype=np.float64) | |
| + | xs = np.array([d.x for d in trail], dtype=np.float64) | |
| + | try: | |
| + | ay, by, cy = np.polyfit(ts, ys, 2) | |
| + | bx, cx = np.polyfit(ts, xs, 1) | |
| + | except Exception: | |
| + | return None | |
| + | disc = by * by - 4 * ay * (cy - plate_y_px) | |
| + | if disc < 0 or abs(ay) < 1e-6: | |
| + | return None | |
| + | sqrt_d = float(np.sqrt(disc)) | |
| + | candidates = [(-by + sqrt_d) / (2 * ay), (-by - sqrt_d) / (2 * ay)] | |
| + | future = [t for t in candidates if t > ts[-1]] | |
| + | if not future: | |
| + | return None | |
| + | t_cross = min(future) | |
| + | plate_x = float(bx * t_cross + cx) | |
| + | eta_ms = float((t_cross - ts[-1]) * 1000.0) | |
| + | if eta_ms < 0 or eta_ms > 2000: | |
| + | return None | |
| + | return plate_x, eta_ms | |
| + | ||
| + | def process(self, frame): | |
| + | ||
| + | full_h, full_w = frame.shape[:2] | |
| + | small = cv2.resize(frame, (full_w // 2, full_h // 2), interpolation=cv2.INTER_AREA) | |
| + | h, w = small.shape[:2] | |
| + | ts_ns = time.time_ns() | |
| + | plate_y_px = h * PLATE_Y_FRAC | |
| + | scale = 2.0 | |
| + | ||
| + | sz_cx = w * SZ_CX_FRAC | |
| + | sz_cy = h * SZ_CY_FRAC | |
| + | sz_w = w * SZ_W_FRAC | |
| + | sz_h = h * SZ_H_FRAC | |
| + | sz_left = sz_cx - sz_w / 2 | |
| + | sz_right = sz_cx + sz_w / 2 | |
| + | sz_top = sz_cy - sz_h / 2 | |
| + | sz_bottom = sz_cy + sz_h / 2 | |
| + | ||
| + | self.last_pci = (sz_cx, sz_cy) | |
| + | ||
| + | ball = self._detect_ball(small, ts_ns) | |
| + | ball_raw = ball | |
| + | if ball_raw is not None: | |
| + | if self.last_ball_pos is not None and self.last_ball_ts: | |
| + | dt = (ts_ns - self.last_ball_ts) / 1e9 | |
| + | if 0.0 < dt <= BALL_MOTION_MAX_GAP_MS / 1000.0: | |
| + | vx = (ball_raw.x - self.last_ball_pos[0]) / dt | |
| + | vy = (ball_raw.y - self.last_ball_pos[1]) / dt | |
| + | self.last_ball_vel = (vx, vy) | |
| + | self.last_ball_pos = (ball_raw.x, ball_raw.y) | |
| + | self.last_ball_r = ball_raw.r | |
| + | self.last_ball_ts = ts_ns | |
| + | ||
| + | if ball is not None: | |
| + | if self.last_det_ns is not None and (ts_ns - self.last_det_ns) > PITCH_GAP_MS * 1e6: | |
| + | self.trail.clear() | |
| + | self.pitch_id += 1 | |
| + | if self.trail: | |
| + | last = self.trail[-1] | |
| + | d = ((ball.x - last.x) ** 2 + (ball.y - last.y) ** 2) ** 0.5 | |
| + | if d > MAX_STEP_PX: | |
| + | self.trail.clear() | |
| + | self.pitch_id += 1 | |
| + | if not self.trail and ball.y > PITCH_START_MAX_Y: | |
| + | ball = None | |
| + | if ball is not None: | |
| + | self.trail.append(ball) | |
| + | self.last_det_ns = ts_ns | |
| + | cutoff = ts_ns - int(TRAJ_WINDOW_S * 1e9) | |
| + | while self.trail and self.trail[0].ts_ns < cutoff: | |
| + | self.trail.popleft() | |
| + | ||
| + | pred = self._try_fit(plate_y_px) if ball is not None else None | |
| + | ||
| + | aim_x = 0.0 | |
| + | aim_y = 0.0 | |
| + | press_contact = 0 | |
| + | press_power = 0 | |
| + | eta_ms = 0 | |
| + | pred_good = 0 | |
| + | target_x = None | |
| + | target_y = None | |
| + | ball_ready_to_swing = False | |
| + | pred_ready_to_swing = False | |
| + | ||
| + | pci_ref = self.last_pci if self.last_pci is not None else (w / 2.0, h * 0.55) | |
| + | ||
| + | if pred is not None: | |
| + | plate_x, eta_ms_f = pred | |
| + | eta_ms = int(max(0, min(32767, eta_ms_f))) | |
| + | pred_good = 1 | |
| + | target_x = plate_x | |
| + | target_y = plate_y_px | |
| + | elif self.last_ball_pos is not None: | |
| + | if (ts_ns - self.last_ball_ts) <= BALL_MEMORY_MS * 1_000_000: | |
| + | age_s = (ts_ns - self.last_ball_ts) / 1e9 | |
| + | vx, vy = self.last_ball_vel | |
| + | if vy > 0 and age_s <= BALL_SWING_MEMORY_MS / 1000.0: | |
| + | target_x = self.last_ball_pos[0] + vx * age_s | |
| + | target_y = self.last_ball_pos[1] + vy * age_s | |
| + | else: | |
| + | target_x = self.last_ball_pos[0] | |
| + | target_y = self.last_ball_pos[1] | |
| + | ||
| + | if target_x is not None: | |
| + | dx = target_x - pci_ref[0] | |
| + | dy = target_y - pci_ref[1] | |
| + | if abs(dx) > AIM_DEADZONE_PX: | |
| + | aim_x = max(-AIM_MAX_STICK, min(AIM_MAX_STICK, AIM_GAIN_X * dx)) | |
| + | if abs(dy) > AIM_DEADZONE_PX: | |
| + | aim_y = max(-AIM_MAX_STICK, min(AIM_MAX_STICK, AIM_GAIN_Y * dy)) | |
| + | ||
| + | residual = (dx ** 2 + dy ** 2) ** 0.5 | |
| + | ||
| + | ball_x = self.last_ball_pos[0] if self.last_ball_pos is not None else 0 | |
| + | ball_y = self.last_ball_pos[1] if self.last_ball_pos is not None else 0 | |
| + | ball_age_s = (ts_ns - self.last_ball_ts) / 1e9 if self.last_ball_ts else 999.0 | |
| + | vx, vy = self.last_ball_vel | |
| + | if self.last_ball_pos is not None and vy > 0 and ball_age_s <= BALL_SWING_MEMORY_MS / 1000.0: | |
| + | ball_x = self.last_ball_pos[0] + vx * ball_age_s | |
| + | ball_y = self.last_ball_pos[1] + vy * ball_age_s | |
| + | ball_recent = (self.last_ball_pos is not None and | |
| + | (ts_ns - self.last_ball_ts) <= BALL_SWING_MEMORY_MS * 1_000_000) | |
| + | cooled_down = (ts_ns - self.last_swing_ts) > 1_500_000_000 | |
| + | zone_pad_x = SWING_ZONE_X_PAD_PX / scale | |
| + | zone_pad_y = SWING_ZONE_Y_PAD_PX / scale | |
| + | swing_lead = SWING_LEAD_PX / scale | |
| + | in_zone_x = (sz_left - zone_pad_x) < ball_x < (sz_right + zone_pad_x) | |
| + | in_zone_y = (sz_top - zone_pad_y) < ball_y < (sz_bottom + zone_pad_y) | |
| + | ||
| + | hittable_size = self.last_ball_r >= SWING_BALL_MIN_R | |
| + | ball_ready_to_swing = self.ball_track_active and ball_recent and in_zone_x and in_zone_y and hittable_size | |
| + | pred_ready_to_swing = False | |
| + | if (ball_ready_to_swing or pred_ready_to_swing) and cooled_down: | |
| + | self.last_swing_ts = ts_ns | |
| + | self.swing_hold_until = ts_ns + SWING_HOLD_MS * 1_000_000 | |
| + | ||
| + | if ts_ns < self.swing_hold_until: | |
| + | press_contact = 1 | |
| + | ||
| + | armed = 1 | |
| + | in_flight = 1 if (self.trail and ball is not None) else 0 | |
| + | debug_flags = (1 if self.last_pci else 0) | (2 if ball else 0) | (4 if pred_good else 0) | |
| + | ||
| + | if DEBUG_FORCE_AIM: | |
| + | aim_x = 25.0 | |
| + | aim_y = 0.0 | |
| + | ||
| + | self.gcvdata[0:4] = _fix32(aim_x) | |
| + | self.gcvdata[4:8] = _fix32(aim_y) | |
| + | self.gcvdata[8:10] = _int16(armed) | |
| + | self.gcvdata[10:12] = _int16(in_flight) | |
| + | self.gcvdata[12:14] = _int16(press_contact) | |
| + | self.gcvdata[14:16] = _int16(press_power) | |
| + | self.gcvdata[16:18] = _int16(eta_ms) | |
| + | self.gcvdata[18:20] = _int16(debug_flags) | |
| + | ||
| + | box_color = (0, 255, 255) if press_contact else ((0, 255, 0) if ball_ready_to_swing or pred_ready_to_swing else (190, 190, 190)) | |
| + | cv2.rectangle(frame, | |
| + | (int(sz_left * scale), int(sz_top * scale)), | |
| + | (int(sz_right * scale), int(sz_bottom * scale)), | |
| + | box_color, 4) | |
| + | if ball is not None: | |
| + | cv2.circle(frame, (int(ball.x * scale), int(ball.y * scale)), max(12, int(ball.r * scale)), (0, 255, 255), 2) | |
| + | if pred is not None: | |
| + | px, _ = pred | |
| + | cv2.drawMarker(frame, (int(px * scale), int(plate_y_px * scale)), (0, 128, 255), cv2.MARKER_CROSS, 48, 3) | |
| + | cv2.putText(frame, f"eta {eta_ms}ms aim=({aim_x:+.1f},{aim_y:+.1f}) contact={press_contact}", | |
| + | (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2, cv2.LINE_AA) | |
| + | ||
| + | hud = f"v41 {'SWING' if press_contact else 'READY' if (ball_ready_to_swing or pred_ready_to_swing) else 'LOCK' if self.ball_track_active else 'WAIT'}" | |
| + | cv2.rectangle(frame, (16, full_h - 92), (430, full_h - 18), (0, 0, 0), -1) | |
| + | cv2.putText(frame, hud, (30, full_h - 42), | |
| + | cv2.FONT_HERSHEY_SIMPLEX, 1.4, box_color, 4, cv2.LINE_AA) | |
| + | ||
| + | return (frame, self.gcvdata) | |
| + | ||
| + | try: | |
| + | import importlib.util as _importlib_util | |
| + | from pathlib import Path as _Path | |
| + | ||
| + | _MODEL_PATH = _Path(__file__).resolve().parents[1] / "configs" / "models" / "mlb26_ball.onnx" | |
| + | _YOLO_PATH = _Path(__file__).with_name("mlb26_gcv_yolo.py") | |
| + | if _MODEL_PATH.exists() and _YOLO_PATH.exists(): | |
| + | _spec = _importlib_util.spec_from_file_location("_mlb26_gcv_yolo", str(_YOLO_PATH)) | |
| + | if _spec is not None and _spec.loader is not None: | |
| + | _module = _importlib_util.module_from_spec(_spec) | |
| + | _spec.loader.exec_module(_module) | |
| + | GCVWorker = _module.GCVWorker | |
| + | except Exception: | |
| + | pass |
| @@ -0,0 +1,642 @@ | ||
| + | """YOLO/ONNX GCVWorker for MLB The Show 26. | |
| + | ||
| + | Use this after training/exporting configs/models/mlb26_ball.onnx. It keeps the | |
| + | same 20-byte GCV packet contract as mlb26_gcv.py and mlb26_bridge.gpc. | |
| + | """ | |
| + | from __future__ import annotations | |
| + | ||
| + | import os | |
| + | import struct | |
| + | import sys | |
| + | import time | |
| + | import ctypes | |
| + | from pathlib import Path | |
| + | ||
| + | import cv2 | |
| + | import numpy as np | |
| + | ||
| + | def _env_roots() -> tuple[Path, ...]: | |
| + | extra = os.environ.get("PITCH_TRACKER_CV_ROOT") | |
| + | if not extra: | |
| + | return () | |
| + | return tuple(Path(p.strip()) for p in extra.split(os.pathsep) if p.strip()) | |
| + | ||
| + | ||
# Directories searched for models, virtualenvs and runtime DLLs: the parent of
# this file's directory first, then any roots from PITCH_TRACKER_CV_ROOT.
PROJECT_ROOTS = (Path(__file__).resolve().parents[1],) + _env_roots()
PREFERRED_ROOTS = PROJECT_ROOTS
# (file name, square network input size) pairs, tried in this order under
# <root>/configs/models; the first file that exists is used.
MODEL_NAMES = (
    ("mlb26_ball_320.onnx", 320),
    ("mlb26_ball_416.onnx", 416),
    ("mlb26_ball_512.onnx", 512),
    ("mlb26_ball.onnx", 640),
)
# Same search list, consulted when falling back to the cv2.dnn backend.
CV2_MODEL_NAMES = (
    ("mlb26_ball_320.onnx", 320),
    ("mlb26_ball_416.onnx", 416),
    ("mlb26_ball_512.onnx", 512),
    ("mlb26_ball.onnx", 640),
)

# Default network input size until a model file has been selected.
YOLO_INPUT = 320
# Minimum detection score; also passed to NMS as its score threshold.
YOLO_CONF = 0.18
# NMS overlap threshold passed to cv2.dnn.NMSBoxes.
YOLO_NMS = 0.45

# Detection region of interest, as (start, end) fractions of frame width/height.
DETECT_X_FRAC = (0.31, 0.66)
DETECT_Y_FRAC = (0.18, 0.84)
# Window in which a new ball may be acquired, as fractions of width/height.
ACQUIRE_X_FRAC = (0.40, 0.60)
ACQUIRE_Y_FRAC = (0.20, 0.52)
# Largest ball diameter (fraction of frame width) accepted in the acquire window.
ACQUIRE_MAX_DIAM_FRAC = 0.0120
# Largest distance (fraction of frame width) a locked ball may move between picks.
MAX_LOCK_STEP_FRAC = 0.090
# A lock older than this is no longer used to match candidates.
MAX_LOCK_AGE_MS = 180
# Lower bound on the vertical step of a locked ball; negative, so a small
# upward move (in image coordinates) is still accepted.
MIN_DOWNWARD_STEP_PX = -10.0

# Zone rectangle (centre and size) as fractions of frame width/height; it is
# drawn on the overlay and used for the swing decision.
SZ_CX_FRAC = 0.492
SZ_CY_FRAC = 0.518
SZ_W_FRAC = 0.203
SZ_H_FRAC = 0.555
# Extra pixels added on every side of the zone for the in-zone test.
SWING_PAD_PX = 0
# Ball diameter (fraction of frame width, floored at 6 px) below which no swing is considered.
SWING_MIN_DIAM_FRAC = 0.004
# How long the contact button stays pressed after a swing is triggered.
SWING_HOLD_MS = 280
# Minimum time between two swing triggers.
SWING_COOLDOWN_MS = 500
# Ball diameter (fraction of frame width) at which the ball counts as close enough.
SWING_HIT_DIAM_FRAC = 0.0070
# Minimum diameter growth rate needed before an ETA is computed.
SWING_MIN_GROWTH_PX_S = 3.0
# An ETA at or below this makes the swing ready.
SWING_ETA_FIRE_MS = 180
# Tracked frames required within a pitch before a swing may fire.
MIN_TRACK_FRAMES_FOR_SWING = 1
# A pitch ends when no ball has been seen for longer than this.
PITCH_END_TIMEOUT_MS = 260

# Multipliers from pixel offset (ball minus zone centre) to the aim value.
AIM_GAIN_X = 0.19
AIM_GAIN_Y = 0.16
# Offsets at or below this many pixels leave the aim axis at 0.
AIM_DEADZONE_PX = 6.0
# Aim values are clamped to +/- this magnitude.
AIM_MAX_STICK = 92.0
# Maximum age of the last detection for velocity updates and extrapolation.
BALL_MEMORY_MS = 130
# NOTE(review): not referenced by the code visible in this file -- confirm it is still needed.
BALL_SWING_MEMORY_MS = 110
# Inference runs on every N-th frame once a ball has been seen.
INFER_EVERY_N_FRAMES = 2

# When True, process() overrides aim_x with a fixed 25.0.
DEBUG_FORCE_AIM = False

# DLL file names that are preloaded (when found in the runtime DLL
# directories) before onnxruntime is imported.
CUDA_DLL_NAMES = (
    "zlibwapi.dll",
    "cudart64_12.dll",
    "cublas64_12.dll",
    "cublasLt64_12.dll",
    "cufft64_11.dll",
    "cufftw64_11.dll",
    "curand64_10.dll",
    "cusparse64_12.dll",
    "cusolver64_11.dll",
    "cusolverMg64_11.dll",
    "nvJitLink_120_0.dll",
    "nvrtc64_120_0.dll",
    "nvrtc-builtins64_126.dll",
    "cudnn64_9.dll",
    "cudnn_ops64_9.dll",
    "cudnn_adv64_9.dll",
    "cudnn_cnn64_9.dll",
    "cudnn_graph64_9.dll",
    "cudnn_heuristic64_9.dll",
    "cudnn_engines_runtime_compiled64_9.dll",
    "cudnn_engines_precompiled64_9.dll",
)

# Second preload list; currently empty, so nothing extra is preloaded.
ORT_DLL_NAMES = ()
| + | ||
| + | def _fix32(value: float) -> bytes: | |
| + | v = int(round(value * 65536.0)) | |
| + | v = max(-2**31, min(2**31 - 1, v)) | |
| + | return struct.pack(">i", v) | |
| + | ||
| + | def _int16(value: int) -> bytes: | |
| + | v = max(-32768, min(32767, int(value))) | |
| + | return struct.pack(">h", v) | |
| + | ||
class GCVWorker:
    """Ball tracker that runs an ONNX detector and fills a 20-byte GCV packet.

    ``process(frame)`` is the per-frame entry point. It detects ball
    candidates, keeps a lock on one of them, derives aim values and a swing
    press, writes them into ``self.gcvdata`` and draws a debug overlay.

    Packet layout written by ``process`` (all fields big-endian):

    ======  ======================================================
    bytes   content
    ======  ======================================================
    0-3     aim_x, 16.16 fixed point
    4-7     aim_y, 16.16 fixed point
    8-9     armed (always 1)
    10-11   in_flight (1 when a detected or extrapolated ball exists)
    12-13   press_contact
    14-15   press_power (always 0)
    16-17   eta_ms
    18-19   debug flags: 2 = detection this frame, 4 = prediction
            usable, 8 = pitch active
    ======  ======================================================
    """

    def __init__(self, width, height):
        """Store the frame size, reset tracking state and load the model.

        Changes the process working directory to this file's directory.
        """
        # abspath() guards against a __file__ without a directory component,
        # for which dirname() would be "" and os.chdir("") would raise.
        os.chdir(os.path.dirname(os.path.abspath(__file__)))
        self.width = width
        self.height = height
        self.gcvdata = bytearray(20)
        # Inference backends: at most one of ``session`` (onnxruntime) and
        # ``net`` (cv2.dnn) is set after _load_model().
        self.net = None
        self.session = None
        self.input_name = None
        self.backend_name = "none"
        self.model_path = None
        self.yolo_input = YOLO_INPUT
        # Diagnostics shown on the overlay and written to the debug file.
        self.model_error = ""
        self.ort_error = ""
        self.runtime_site = ""
        self.runtime_dll_dirs = []
        self.preload_error = ""
        # Swing timing (time.time_ns() based).
        self.last_swing_ts = 0
        self.swing_hold_until = 0
        # Ball records are (cx, cy, w, h, conf, ts_ns) tuples.
        self.last_ball = None
        self.prev_ball = None
        self.locked_ball = None
        self.last_ball_vel = (0.0, 0.0)
        # Per-pitch state.
        self.pitch_active = False
        self.pitch_swinged = False
        self.track_frames = 0
        self.pitch_last_seen_ns = 0
        self.frame_id = 0
        self._load_model()

    def _write_debug(self, message: str) -> None:
        """Write a runtime diagnostics snapshot into every existing project root.

        The file ``gcv_yolo_runtime_debug.txt`` is overwritten each time.
        """
        lines = [
            f"time={time.strftime('%Y-%m-%d %H:%M:%S')}",
            f"message={message}",
            f"file={__file__}",
            f"cwd={os.getcwd()}",
            f"python={sys.version}",
            f"executable={sys.executable}",
            f"model_path={self.model_path}",
            f"backend={self.backend_name}",
            f"yolo_input={self.yolo_input}",
            f"runtime_site={self.runtime_site}",
            f"runtime_dll_dirs={self.runtime_dll_dirs}",
            f"preload_error={self.preload_error}",
            f"ort_error={self.ort_error}",
        ]
        for root in PROJECT_ROOTS:
            try:
                if root.exists():
                    (root / "gcv_yolo_runtime_debug.txt").write_text("\n".join(lines), encoding="utf-8")
            except Exception:
                # Diagnostics are best effort; a read-only root must not
                # prevent the worker from running.
                pass

    @staticmethod
    def _runtime_site_candidates():
        """Yield candidate site-packages directories, one pair per unique root."""
        seen = set()
        for root in PREFERRED_ROOTS:
            if str(root) in seen:
                continue
            seen.add(str(root))
            yield root / ".venv_yolo" / "Lib" / "site-packages"
            yield root / ".venv" / "Lib" / "site-packages"

    @staticmethod
    def _add_dll_dirs(dll_dirs):
        """Register the existing directories in ``dll_dirs`` for DLL lookup.

        Existing directories are prepended to ``PATH`` and, where available,
        passed to ``os.add_dll_directory``. Returns the de-duplicated list of
        directories that exist, in input order.

        ``None`` and the empty path are skipped. ``Path()`` is "." and always
        exists, so treating it as a real entry would silently put the current
        working directory on the DLL search path.
        """
        existing = []
        for dll_dir in dll_dirs:
            if dll_dir is None or str(dll_dir) in ("", "."):
                continue
            if dll_dir.exists() and dll_dir not in existing:
                existing.append(dll_dir)
        if existing:
            os.environ["PATH"] = os.pathsep.join(str(d) for d in existing) + os.pathsep + os.environ.get("PATH", "")
        if hasattr(os, "add_dll_directory"):
            for dll_dir in existing:
                try:
                    os.add_dll_directory(str(dll_dir))
                except Exception:
                    # Best effort: the PATH entry above still applies.
                    pass
        return existing

    @staticmethod
    def _preload_dlls(dll_dirs, names):
        """Load each named DLL from the first directory that contains it.

        Returns a "; "-joined summary of at most six load errors, or "" when
        nothing failed (including when no file was found at all).
        """
        errors = []
        for name in names:
            for dll_dir in dll_dirs:
                dll_path = dll_dir / name
                if not dll_path.exists():
                    continue
                try:
                    ctypes.WinDLL(str(dll_path))
                except Exception as exc:
                    errors.append(f"{name}:{str(exc)[:70]}")
                # Only the first directory holding the file is tried.
                break
        return "; ".join(errors[:6])

    def __del__(self):
        try:
            del self.gcvdata
        except Exception:
            pass

    @staticmethod
    def _find_model(model_names):
        """Return ``(path, input_size)`` of the first existing model file.

        Roots are searched in ``PREFERRED_ROOTS`` order and, inside each root,
        names in ``model_names`` order. Returns ``(None, None)`` when no file
        exists.
        """
        for project_root in PREFERRED_ROOTS:
            model_dir = project_root / "configs" / "models"
            for model_name, input_size in model_names:
                model_path = model_dir / model_name
                if model_path.exists():
                    return model_path, input_size
        return None, None

    def _load_model(self) -> None:
        """Locate the ONNX model and initialise an inference backend.

        onnxruntime is tried first (CUDA provider, then CPU). If that fails,
        the error is kept in ``self.ort_error`` and cv2.dnn on the CPU is used
        instead. Any other failure leaves both backends unset and stores a
        short message in ``self.model_error``.
        """
        try:
            found_path, found_size = self._find_model(MODEL_NAMES)
            if found_path is not None:
                self.model_path = found_path
                self.yolo_input = found_size
            if self.model_path is None:
                self.model_error = "missing mlb26 ball ONNX"
                return

            # Pick the first virtualenv that actually contains onnxruntime.
            yolo_site = None
            for candidate in self._runtime_site_candidates():
                if (candidate / "onnxruntime").exists():
                    yolo_site = candidate
                    break
            # None (not Path()) marks "no such directory"; see _add_dll_dirs.
            yolo_capi = yolo_site / "onnxruntime" / "capi" if yolo_site else None
            torch_lib = yolo_site / "torch" / "lib" if yolo_site else None
            runtime_dlls = None
            for root in PREFERRED_ROOTS:
                candidate = root / "runtime_dlls"
                if candidate.exists():
                    runtime_dlls = candidate
                    break
            if yolo_site and str(yolo_site) not in sys.path:
                sys.path.insert(0, str(yolo_site))
            self.runtime_site = str(yolo_site) if yolo_site else ""
            dll_dirs = self._add_dll_dirs((runtime_dlls, torch_lib, yolo_capi))
            self.runtime_dll_dirs = [str(d) for d in dll_dirs]

            preload_errors = []
            cuda_preload = self._preload_dlls(dll_dirs, CUDA_DLL_NAMES)
            if cuda_preload:
                preload_errors.append(cuda_preload)
            ort_preload = self._preload_dlls(dll_dirs, ORT_DLL_NAMES)
            if ort_preload:
                preload_errors.append(ort_preload)
            self.preload_error = " | ".join(preload_errors)[:240]

            try:
                import onnxruntime as ort

                providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
                self.session = ort.InferenceSession(str(self.model_path), providers=providers)
                self.input_name = self.session.get_inputs()[0].name
                active = self.session.get_providers()
                self.backend_name = "ort" if "CUDAExecutionProvider" in active else "ortcpu"
                self.model_error = ""
                self._write_debug(f"loaded onnxruntime providers={active}")
                return
            except Exception as exc:
                self.ort_error = str(exc)[:120]
                self.session = None
                self.input_name = None

            # cv2.dnn fallback: prefer a model from the cv2 list, otherwise
            # keep the model that was already selected above.
            cv2_path, cv2_size = self._find_model(CV2_MODEL_NAMES)
            if cv2_path is not None:
                self.model_path = cv2_path
                self.yolo_input = cv2_size
            self.net = cv2.dnn.readNetFromONNX(str(self.model_path))
            self.net.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
            self.net.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)
            self.backend_name = "cv2"
            self.model_error = ""
            self._write_debug("loaded cv2 fallback")
        except Exception as exc:
            self.net = None
            self.session = None
            self.model_error = str(exc)[:80]

    @staticmethod
    def _rows_from_output(output):
        """Normalise raw network output to a 2-D array with one row per detection.

        A leading batch axis is dropped. An array with at most 85 rows and
        more columns than rows is treated as channel-major and transposed.
        Anything that is not 2-D afterwards yields an empty ``(0, 0)`` array.
        """
        out = output[0] if isinstance(output, (tuple, list)) else output
        out = np.asarray(out)
        if out.ndim == 3:
            out = out[0]
        if out.ndim != 2:
            return np.empty((0, 0), dtype=np.float32)
        if out.shape[0] <= 85 and out.shape[0] < out.shape[1]:
            out = out.T
        return out

    def _detect_candidates(self, frame):
        """Run the detector on the detection ROI and return ball candidates.

        Returns a list of ``(cx, cy, w, h, conf)`` tuples in full-frame pixel
        coordinates, after score, aspect-ratio and size filtering and NMS.
        Returns ``[]`` when no backend is loaded.
        """
        if self.net is None and self.session is None:
            return []

        h, w = frame.shape[:2]
        x0 = max(0, int(w * DETECT_X_FRAC[0]))
        x1 = min(w, int(w * DETECT_X_FRAC[1]))
        y0 = max(0, int(h * DETECT_Y_FRAC[0]))
        y1 = min(h, int(h * DETECT_Y_FRAC[1]))
        if x1 <= x0 or y1 <= y0:
            return []

        roi = frame[y0:y1, x0:x1]
        roi_h, roi_w = roi.shape[:2]
        blob = cv2.dnn.blobFromImage(
            roi,
            scalefactor=1.0 / 255.0,
            size=(self.yolo_input, self.yolo_input),
            mean=(0, 0, 0),
            swapRB=True,
            crop=False,
        )
        if self.session is not None:
            output = self.session.run(None, {self.input_name: blob})
        else:
            self.net.setInput(blob)
            output = self.net.forward()
        rows = self._rows_from_output(output)

        boxes = []
        scores = []
        for row in rows:
            if row.shape[0] < 5:
                continue
            # Score column depends on the row width: 5 -> column 4; 6 ->
            # column 4 unless column 5 exceeds 1.0; wider -> max of 4 onward.
            if row.shape[0] == 5:
                score = float(row[4])
            elif row.shape[0] == 6:
                score = float(row[4]) if row[5] <= 1.0 else float(row[5])
            else:
                score = float(np.max(row[4:]))
            if score < YOLO_CONF:
                continue
            cx, cy, bw, bh = [float(v) for v in row[:4]]
            aspect = bw / max(1.0, bh)
            if aspect < 0.42 or aspect > 2.35:
                continue
            # Rescale from network-input coordinates to ROI pixels.
            left = (cx - bw / 2.0) * roi_w / self.yolo_input
            top = (cy - bh / 2.0) * roi_h / self.yolo_input
            ww = bw * roi_w / self.yolo_input
            hh = bh * roi_h / self.yolo_input
            if ww < 3 or hh < 3:
                continue
            if ww > roi_w * 0.16 or hh > roi_h * 0.16:
                continue
            boxes.append([int(left), int(top), int(ww), int(hh)])
            scores.append(score)

        if not boxes:
            return []
        keep = cv2.dnn.NMSBoxes(boxes, scores, YOLO_CONF, YOLO_NMS)
        if len(keep) == 0:
            return []
        candidates = []
        for idx in np.array(keep).reshape(-1):
            bx, by, bw, bh = boxes[int(idx)]
            conf = scores[int(idx)]
            # Convert the ROI-relative top-left box to a full-frame centre.
            candidates.append((
                float(x0 + bx + bw / 2.0),
                float(y0 + by + bh / 2.0),
                float(bw),
                float(bh),
                float(conf),
            ))
        return candidates

    def _choose_ball(self, candidates, frame_w, frame_h, ts_ns):
        """Pick the candidate to track and update ``self.locked_ball``.

        A fresh lock is continued with the nearest plausible candidate
        (distance minus a confidence bonus). Otherwise the most confident
        candidate inside the acquire window is taken, or failing that the most
        confident small candidate in the loosened window. Returns the chosen
        ``(cx, cy, w, h, conf)`` tuple or ``None``.
        """
        if not candidates:
            # Drop the lock once the last detection is too old.
            if self.last_ball and (ts_ns - self.last_ball[5]) > MAX_LOCK_AGE_MS * 1_000_000:
                self.locked_ball = None
            return None

        acquire_left = frame_w * ACQUIRE_X_FRAC[0]
        acquire_right = frame_w * ACQUIRE_X_FRAC[1]
        acquire_top = frame_h * ACQUIRE_Y_FRAC[0]
        acquire_bottom = frame_h * ACQUIRE_Y_FRAC[1]
        max_step = frame_w * MAX_LOCK_STEP_FRAC
        if self.locked_ball is not None:
            lx, ly, lw, lh, lconf, lts = self.locked_ball
            age_ms = (ts_ns - lts) / 1_000_000
            if age_ms <= MAX_LOCK_AGE_MS:
                scored = []
                for cand in candidates:
                    cx, cy, cw, ch, conf = cand
                    dx = cx - lx
                    dy = cy - ly
                    dist = (dx * dx + dy * dy) ** 0.5
                    size_ratio = max(cw, ch) / max(1.0, max(lw, lh))
                    if dist <= max_step and dy >= MIN_DOWNWARD_STEP_PX and 0.35 <= size_ratio <= 3.25:
                        scored.append((dist - conf * 35.0, cand))
                if scored:
                    chosen = min(scored, key=lambda item: item[0])[1]
                    self.locked_ball = (*chosen, ts_ns)
                    return chosen

        acquire = [
            c for c in candidates
            if acquire_left <= c[0] <= acquire_right and acquire_top <= c[1] <= acquire_bottom
        ]
        if not acquire:
            loose = [
                c for c in candidates
                if c[1] <= frame_h * (ACQUIRE_Y_FRAC[1] + 0.10) and max(c[2], c[3]) <= frame_w * 0.020
            ]
            if not loose:
                return None
            chosen = max(loose, key=lambda c: c[4])
            self.locked_ball = (*chosen, ts_ns)
            return chosen
        chosen = max(acquire, key=lambda c: c[4])
        self.locked_ball = (*chosen, ts_ns)
        return chosen

    @staticmethod
    def _segment_hits_box(prev, cur, left, right, top, bottom):
        """Return True if the downward segment prev -> cur touches the box.

        Only segments with strictly increasing y are considered. The segment
        hits when either endpoint lies inside the box or when it crosses the
        top or bottom edge between ``left`` and ``right``.
        """
        if prev is None or cur is None:
            return False
        px, py = prev[:2]
        cx, cy = cur[:2]
        if cy <= py:
            return False
        if left <= px <= right and top <= py <= bottom:
            return True
        if left <= cx <= right and top <= cy <= bottom:
            return True
        if cy < top or py > bottom:
            return False
        for y in (top, bottom):
            if py <= y <= cy:
                # cy > py is guaranteed above, so divide by the true span;
                # clamping the divisor to 1.0 skews sub-pixel steps.
                t = (y - py) / (cy - py)
                x = px + (cx - px) * t
                if left <= x <= right:
                    return True
        return False

    @staticmethod
    def _ball_in_zone(ball, sz_left, sz_right, sz_top, sz_bottom):
        """Return True when the ball centre lies inside the given rectangle."""
        bx, by = ball[0], ball[1]
        return sz_left <= bx <= sz_right and sz_top <= by <= sz_bottom

    def _compute_swing_ready(self, prev_ball, cur_ball, sz_left, sz_right, sz_top, sz_bottom, full_w):
        """Decide whether the current ball warrants a swing.

        Returns ``(ready, eta_ms, pred_good)``. The ball must be at least the
        minimum diameter and inside the (padded) zone. It is ready when its
        diameter has reached the hit diameter, or when the diameter growth
        since ``prev_ball`` predicts reaching it within ``SWING_ETA_FIRE_MS``.
        """
        ball_diam = max(cur_ball[2], cur_ball[3])
        min_track_diam = max(6.0, full_w * SWING_MIN_DIAM_FRAC)
        hit_diam = max(min_track_diam + 2.0, full_w * SWING_HIT_DIAM_FRAC)
        if ball_diam < min_track_diam:
            return False, 0, 0
        if not self._ball_in_zone(cur_ball, sz_left - SWING_PAD_PX, sz_right + SWING_PAD_PX, sz_top - SWING_PAD_PX, sz_bottom + SWING_PAD_PX):
            return False, 0, 0

        close_by_size = ball_diam >= hit_diam
        eta_ms = 0
        pred_good = 0
        if prev_ball is not None:
            prev_diam = max(prev_ball[2], prev_ball[3])
            dt_s = max(1e-6, (cur_ball[5] - prev_ball[5]) / 1e9)
            growth_px_s = (ball_diam - prev_diam) / dt_s
            if growth_px_s >= SWING_MIN_GROWTH_PX_S and ball_diam < hit_diam:
                eta = (hit_diam - ball_diam) * 1000.0 / max(growth_px_s, 1.0)
                # Keep the ETA within the int16 packet field.
                if 0.0 <= eta <= 32767.0:
                    eta_ms = int(eta)
                    pred_good = 1
        close_soon = 0 < eta_ms <= SWING_ETA_FIRE_MS
        ready = close_by_size or close_soon
        if close_by_size:
            pred_good = 1
        return ready, eta_ms, pred_good

    @staticmethod
    def _in_acquire_window(ball, full_w, full_h):
        """Return True when the ball is inside the acquire window and small enough."""
        bx, by, bw, bh = ball[0], ball[1], ball[2], ball[3]
        max_d = full_w * ACQUIRE_MAX_DIAM_FRAC
        return (
            full_w * ACQUIRE_X_FRAC[0] <= bx <= full_w * ACQUIRE_X_FRAC[1]
            and full_h * ACQUIRE_Y_FRAC[0] <= by <= full_h * ACQUIRE_Y_FRAC[1]
            and max(bw, bh) <= max_d
        )

    def process(self, frame):
        """Process one frame; return ``(frame, gcvdata)``.

        The frame is annotated in place with the debug overlay and
        ``self.gcvdata`` is refreshed with the packet described on the class.
        """
        ts_ns = time.time_ns()
        full_h, full_w = frame.shape[:2]
        self.frame_id += 1
        # Infer on every frame until a ball has been seen once, then only on
        # every N-th frame; skipped frames extrapolate the last detection.
        do_infer = (self.last_ball is None) or (self.frame_id % max(1, INFER_EVERY_N_FRAMES) == 0)
        if do_infer:
            candidates = self._detect_candidates(frame)
            ball = self._choose_ball(candidates, full_w, full_h, ts_ns)
        else:
            candidates = []
            ball = None

        sz_cx = full_w * SZ_CX_FRAC
        sz_cy = full_h * SZ_CY_FRAC
        sz_w = full_w * SZ_W_FRAC
        sz_h = full_h * SZ_H_FRAC
        sz_left = sz_cx - sz_w / 2
        sz_right = sz_cx + sz_w / 2
        sz_top = sz_cy - sz_h / 2
        sz_bottom = sz_cy + sz_h / 2

        aim_x = 0.0
        aim_y = 0.0
        press_contact = 0
        press_power = 0
        eta_ms = 0
        pred_good = 0
        ready = False
        eval_ball = None
        is_real_ball = ball is not None

        if ball is not None:
            bx, by, bw, bh, conf = ball
            if self.last_ball is not None:
                px, py, pw, ph, pconf, pts_ns = self.last_ball
                dt = (ts_ns - pts_ns) / 1e9
                if 0.0 < dt <= BALL_MEMORY_MS / 1000.0:
                    self.last_ball_vel = ((bx - px) / dt, (by - py) / dt)
                else:
                    # The previous detection is too old to difference
                    # against; forget the velocity instead of extrapolating
                    # this ball with an earlier pitch's motion.
                    self.last_ball_vel = (0.0, 0.0)
            if self.pitch_active:
                self.track_frames += 1
            elif (
                self._in_acquire_window((bx, by, bw, bh, conf, ts_ns), full_w, full_h)
                or (by <= (sz_top + 0.62 * sz_h) and max(bw, bh) <= full_w * 0.018)
            ):
                # A new pitch starts on a small ball high enough in the frame.
                self.pitch_active = True
                self.pitch_swinged = False
                self.track_frames = 1
            else:
                self.track_frames = 0
            self.pitch_last_seen_ns = ts_ns
            # The pitch ends once the ball is well below the zone.
            if self.pitch_active and by > (sz_bottom + 0.25 * sz_h):
                self.pitch_active = False
                self.pitch_swinged = False
                self.track_frames = 0
            ready, eta_local, pred_local = self._compute_swing_ready(
                self.last_ball,
                (bx, by, bw, bh, conf, ts_ns),
                sz_left,
                sz_right,
                sz_top,
                sz_bottom,
                full_w,
            )
            if eta_local > 0:
                eta_ms = eta_local
            pred_good = max(pred_good, pred_local)
            dx = bx - sz_cx
            dy = by - sz_cy
            if abs(dx) > AIM_DEADZONE_PX:
                aim_x = max(-AIM_MAX_STICK, min(AIM_MAX_STICK, AIM_GAIN_X * dx))
            if abs(dy) > AIM_DEADZONE_PX:
                aim_y = max(-AIM_MAX_STICK, min(AIM_MAX_STICK, AIM_GAIN_Y * dy))
            cooled = (ts_ns - self.last_swing_ts) > SWING_COOLDOWN_MS * 1_000_000
            if (
                ready
                and cooled
                and is_real_ball
                and (not self.pitch_swinged or (ts_ns - self.last_swing_ts) > 900 * 1_000_000)
                and self.track_frames >= MIN_TRACK_FRAMES_FOR_SWING
            ):
                self.last_swing_ts = ts_ns
                self.swing_hold_until = ts_ns + SWING_HOLD_MS * 1_000_000
                self.pitch_swinged = True
            self.prev_ball = self.last_ball
            self.last_ball = (bx, by, bw, bh, conf, ts_ns)
            eval_ball = self.last_ball
        elif self.last_ball is not None:
            # No detection this frame: time the pitch out if needed, then
            # extrapolate the last detection while it is still recent.
            if self.pitch_active and self.pitch_last_seen_ns > 0:
                if (ts_ns - self.pitch_last_seen_ns) > PITCH_END_TIMEOUT_MS * 1_000_000:
                    self.pitch_active = False
                    self.pitch_swinged = False
                    self.track_frames = 0
            bx, by, bw, bh, conf, last_ts = self.last_ball
            age_ns = ts_ns - last_ts
            if age_ns <= BALL_MEMORY_MS * 1_000_000:
                age_s = age_ns / 1e9
                vx, vy = self.last_ball_vel
                bx = bx + vx * age_s
                by = by + vy * age_s
                dx = bx - sz_cx
                dy = by - sz_cy
                if abs(dx) > AIM_DEADZONE_PX:
                    aim_x = max(-AIM_MAX_STICK, min(AIM_MAX_STICK, AIM_GAIN_X * dx))
                if abs(dy) > AIM_DEADZONE_PX:
                    aim_y = max(-AIM_MAX_STICK, min(AIM_MAX_STICK, AIM_GAIN_Y * dy))
                eval_ball = (bx, by, bw, bh, conf, ts_ns)

        # The press is held for SWING_HOLD_MS after the trigger.
        if ts_ns < self.swing_hold_until:
            press_contact = 1

        if DEBUG_FORCE_AIM:
            aim_x = 25.0

        armed = 1
        in_flight = 1 if eval_ball is not None else 0
        debug_flags = (2 if is_real_ball else 0) | (4 if pred_good else 0) | (8 if self.pitch_active else 0)

        self.gcvdata[0:4] = _fix32(aim_x)
        self.gcvdata[4:8] = _fix32(aim_y)
        self.gcvdata[8:10] = _int16(armed)
        self.gcvdata[10:12] = _int16(in_flight)
        self.gcvdata[12:14] = _int16(press_contact)
        self.gcvdata[14:16] = _int16(press_power)
        self.gcvdata[16:18] = _int16(eta_ms)
        self.gcvdata[18:20] = _int16(debug_flags)

        # Debug overlay: zone box, detection ROI, tracked ball or candidates.
        box_color = (0, 255, 255) if press_contact else ((0, 255, 0) if ready else (190, 190, 190))
        cv2.rectangle(frame, (int(sz_left), int(sz_top)), (int(sz_right), int(sz_bottom)), box_color, 6)
        dx0 = int(full_w * DETECT_X_FRAC[0])
        dx1 = int(full_w * DETECT_X_FRAC[1])
        dy0 = int(full_h * DETECT_Y_FRAC[0])
        dy1 = int(full_h * DETECT_Y_FRAC[1])
        cv2.rectangle(frame, (dx0, dy0), (dx1, dy1), (80, 80, 80), 1)
        if eval_ball is not None:
            bx, by, bw, bh, conf = eval_ball[:5]
            cv2.circle(frame, (int(bx), int(by)), max(8, int(max(bw, bh) / 2)), (0, 255, 255), 2)
            cv2.putText(frame, f"{conf:.2f}", (int(bx) + 10, int(by) - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
        elif candidates:
            for bx, by, bw, bh, conf in candidates[:3]:
                cv2.circle(frame, (int(bx), int(by)), 5, (0, 128, 255), 1)

        status = "NO MODEL" if (self.net is None and self.session is None) else (
            "SWING" if press_contact else "READY" if ready else "TRACK" if eval_ball is not None else "DETECT" if is_real_ball else "WAIT"
        )
        cv2.rectangle(frame, (16, full_h - 92), (500, full_h - 18), (0, 0, 0), -1)
        cv2.putText(frame, f"YOLO v2 {self.backend_name}{self.yolo_input} {status}", (30, full_h - 42),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.35, box_color, 4, cv2.LINE_AA)
        if self.net is None and self.session is None and self.model_error:
            cv2.putText(frame, self.model_error, (30, 44),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2, cv2.LINE_AA)
        elif self.backend_name == "cv2" and self.ort_error:
            cv2.putText(frame, f"ORT fail: {self.ort_error[:90]}", (30, 44),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.58, (0, 128, 255), 2, cv2.LINE_AA)

        return frame, self.gcvdata
| @@ -0,0 +1,4 @@ | ||
| + | ultralytics>=8.3 | |
| + | onnx>=1.16 | |
| + | onnxruntime>=1.19 | |
| + | rich>=13 |
| @@ -0,0 +1 @@ | ||
| + |
| @@ -0,0 +1,86 @@ | ||
| + | """One-shot detection snapshot. | |
| + | ||
| + | Opens the capture card directly (bypasses ZMQ), grabs one frame, runs the | |
| + | ball and PCI detectors, and saves an annotated PNG to logs/snapshot.png. | |
| + | Useful for eyeballing what the current HSV ranges actually pick up on | |
| + | whatever screen the Xbox is showing right now. | |
| + | """ | |
| + | from __future__ import annotations | |
| + | ||
| + | import sys | |
| + | import time | |
| + | from pathlib import Path | |
| + | ||
| + | import cv2 | |
| + | import numpy as np | |
| + | from rich.console import Console | |
| + | ||
| + | sys.path.insert(0, str(Path(__file__).resolve().parents[1])) | |
| + | from capture.ingest import open_capture # noqa: E402 | |
| + | from cv._common import load_config # noqa: E402 | |
| + | from cv.ball_tracker import detect_ball # noqa: E402 | |
| + | from cv.pci_tracker import detect_by_hsv as pci_detect # noqa: E402 | |
| + | ||
| + | console = Console() | |
| + | OUT_DIR = Path(__file__).resolve().parents[1] / "logs" | |
| + | ||
def main() -> int:
    """Grab one frame, run the ball and PCI detectors, and save two PNGs.

    Writes ``snapshot_raw.png`` (the untouched frame) and ``snapshot.png``
    (the annotated overlay) into ``OUT_DIR``.

    Returns:
        0 on success, 2 when the capture device cannot be opened or yields no
        frame, 1 when an image cannot be written.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    cfg = load_config()
    cap = open_capture(cfg)
    if cap is None:
        console.print("[red]Could not open capture card.[/red]")
        return 2
    try:
        if not cap.isOpened():
            console.print("[red]Could not open capture card.[/red]")
            return 2

        # Discard a few frames before taking the one that is analysed.
        for _ in range(5):
            cap.read()
            time.sleep(0.02)

        ok, frame = cap.read()
    finally:
        # Release the device on every path, including early exits and errors.
        cap.release()
    if not ok or frame is None:
        console.print("[red]Capture returned no frame.[/red]")
        return 2

    raw_out = OUT_DIR / "snapshot_raw.png"
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(str(raw_out), frame):
        console.print(f"[red]Could not write {raw_out}[/red]")
        return 1
    console.print(f"[green]Saved raw -> {raw_out}[/green]")

    overlay = frame.copy()
    h, w = frame.shape[:2]
    # Horizontal reference line at the configured plate height.
    plate_y = int(h * float(cfg["cv"].get("plate_y_frac", 0.72)))
    cv2.line(overlay, (0, plate_y), (w, plate_y), (255, 255, 255), 1, cv2.LINE_AA)

    ball = detect_ball(frame, cfg)
    if ball is not None:
        cv2.circle(overlay, (int(ball.x), int(ball.y)), max(int(ball.r), 4), (0, 255, 255), 2)
        cv2.putText(
            overlay,
            f"ball r={ball.r:.0f} s={ball.score:.2f}",
            (int(ball.x) + 10, int(ball.y) - 10),
            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2, cv2.LINE_AA,
        )
        console.print(f"[bold]ball:[/bold] x={ball.x:.0f} y={ball.y:.0f} r={ball.r:.1f} score={ball.score:.2f}")
    else:
        console.print("[yellow]ball: none[/yellow]")

    pci = pci_detect(frame, cfg["cv"]["pci"])
    if pci is not None:
        cv2.circle(overlay, (int(pci["x"]), int(pci["y"])), int(pci["r"]), (0, 255, 0), 2)
        cv2.putText(
            overlay,
            f"pci r={pci['r']:.0f}",
            (int(pci["x"]) + 10, int(pci["y"]) + 20),
            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA,
        )
        console.print(f"[bold]pci:[/bold] x={pci['x']:.0f} y={pci['y']:.0f} r={pci['r']:.1f} score={pci['score']:.2f}")
    else:
        console.print("[yellow]pci: none[/yellow]")

    out = OUT_DIR / "snapshot.png"
    if not cv2.imwrite(str(out), overlay):
        console.print(f"[red]Could not write {out}[/red]")
        return 1
    console.print(f"[green]Saved -> {out}[/green]")
    return 0
| + | ||
| + | if __name__ == "__main__": | |
| + | sys.exit(main()) |
| @@ -0,0 +1,71 @@ | ||
| + | """Capture live frames for one-class MLB baseball YOLO training. | |
| + | ||
| + | Saves raw frames to datasets/mlb26_ball_yolo/images/all. Label these images | |
| + | with class 0 = ball, then run yolo_split_dataset.py. | |
| + | """ | |
| + | from __future__ import annotations | |
| + | ||
| + | import argparse | |
| + | import shutil | |
| + | import sys | |
| + | import time | |
| + | from pathlib import Path | |
| + | ||
| + | import cv2 | |
| + | from rich.console import Console | |
| + | ||
| + | sys.path.insert(0, str(Path(__file__).resolve().parents[1])) | |
| + | from capture.ingest import open_capture # noqa: E402 | |
| + | from cv._common import load_config # noqa: E402 | |
| + | ||
| + | ROOT = Path(__file__).resolve().parents[1] | |
| + | OUT_DIR = ROOT / "datasets" / "mlb26_ball_yolo" / "images" / "all" | |
| + | console = Console() | |
| + | ||
def main() -> int:
    """Capture live frames to ``OUT_DIR`` as JPEGs for later labeling.

    Frames are read as fast as the device delivers them and one is saved
    whenever at least ``1 / fps`` seconds have passed since the previous save.

    Returns:
        0 when the capture loop ran, 2 when the capture device cannot be
        opened.
    """
    ap = argparse.ArgumentParser(description="Capture frames for YOLO ball labeling.")
    ap.add_argument("--duration", type=float, default=120.0, help="Capture duration in seconds.")
    ap.add_argument("--fps", type=float, default=12.0, help="Saved frame rate.")
    ap.add_argument("--clear", action="store_true", help="Clear images/all before capture.")
    ap.add_argument("--prefix", default="mlb26", help="Output filename prefix.")
    args = ap.parse_args()

    if args.clear and OUT_DIR.exists():
        shutil.rmtree(OUT_DIR)
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    cap = open_capture(load_config())
    if cap is None or not cap.isOpened():
        console.print("[red]Could not open capture card.[/red]")
        return 2

    # Clamp the rate so a zero or negative --fps cannot divide by zero.
    interval = 1.0 / max(args.fps, 0.1)
    deadline = time.perf_counter() + args.duration
    next_save = time.perf_counter()
    saved = 0
    failed_writes = 0
    read_count = 0

    console.print(f"[cyan]Capturing {args.duration:.0f}s at {args.fps:.1f} saved fps -> {OUT_DIR}[/cyan]")
    try:
        while time.perf_counter() < deadline:
            ok, frame = cap.read()
            if not ok or frame is None:
                # Back off briefly so a stalled device does not spin a core
                # for the rest of the capture window.
                time.sleep(0.005)
                continue
            read_count += 1
            now = time.perf_counter()
            if now < next_save:
                continue
            out = OUT_DIR / f"{args.prefix}_{int(time.time() * 1000)}_{saved:05d}.jpg"
            next_save = now + interval
            # cv2.imwrite signals failure via its return value; only count
            # frames that actually reached the disk.
            if not cv2.imwrite(str(out), frame, [cv2.IMWRITE_JPEG_QUALITY, 95]):
                failed_writes += 1
                continue
            saved += 1
            if saved % 50 == 0:
                console.print(f"  saved {saved} frames")
    finally:
        cap.release()

    console.print(f"[green]Saved {saved} frames. Capture reads: {read_count}.[/green]")
    if failed_writes:
        console.print(f"[yellow]{failed_writes} frames could not be written.[/yellow]")
    return 0
| + | ||
| + | if __name__ == "__main__": | |
| + | raise SystemExit(main()) |
| @@ -0,0 +1,37 @@ | ||
| + | """Run the trained ball detector on one frame and save an annotated image.""" | |
| + | from __future__ import annotations | |
| + | ||
| + | import argparse | |
| + | from pathlib import Path | |
| + | ||
| + | import cv2 | |
| + | from rich.console import Console | |
| + | ||
| + | ROOT = Path(__file__).resolve().parents[1] | |
| + | MODEL = ROOT / "configs" / "models" / "mlb26_ball.pt" | |
| + | console = Console() | |
| + | ||
def main() -> int:
    """Run the trained ball model on one image and save the annotated result.

    Prints one line per detected box with its confidence and xyxy corners.

    Returns:
        0 on success, 2 when the input image does not exist, 1 when the
        annotated image cannot be written.
    """
    ap = argparse.ArgumentParser(description="Infer MLB ball model on one image.")
    ap.add_argument("image", type=Path)
    ap.add_argument("--model", type=Path, default=MODEL)
    ap.add_argument("--conf", type=float, default=0.25)
    ap.add_argument("--out", type=Path, default=ROOT / "logs" / "yolo_infer.png")
    ap.add_argument("--imgsz", type=int, default=640, help="Inference image size.")
    args = ap.parse_args()

    if not args.image.is_file():
        console.print(f"[red]Image not found: {args.image}[/red]")
        return 2

    # Imported here so argument errors are reported without loading ultralytics.
    from ultralytics import YOLO

    model = YOLO(str(args.model))
    results = model.predict(source=str(args.image), conf=args.conf, imgsz=args.imgsz, verbose=False)
    annotated = results[0].plot()
    args.out.parent.mkdir(parents=True, exist_ok=True)
    # cv2.imwrite signals failure via its return value, not an exception.
    if not cv2.imwrite(str(args.out), annotated):
        console.print(f"[red]Could not write {args.out}[/red]")
        return 1
    console.print(f"[green]Saved annotated output -> {args.out}[/green]")
    for box in results[0].boxes:
        xyxy = [float(x) for x in box.xyxy[0]]
        conf = float(box.conf[0])
        console.print(f"ball conf={conf:.3f} xyxy={xyxy}")
    return 0
| + | ||
| + | if __name__ == "__main__": | |
| + | raise SystemExit(main()) |
| @@ -0,0 +1,178 @@ | ||
| + | """Tiny OpenCV YOLO labeler for class 0 = ball. | |
| + | ||
| + | Controls: | |
| + | drag left mouse: draw ball box, save, and go next | |
| + | s, Space, or Enter: save label and go next | |
| + | n: next without saving/changing | |
| + | b: previous | |
| + | c: clear current label | |
| + | e: save empty/no-ball label and go next | |
| + | q or Esc: quit | |
| + | """ | |
| + | from __future__ import annotations | |
| + | ||
| + | import argparse | |
| + | from pathlib import Path | |
| + | ||
| + | import cv2 | |
| + | from rich.console import Console | |
| + | ||
# Parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Dataset folder; images and labels live in parallel "all" subfolders.
DATA = ROOT / "datasets" / "mlb26_ball_yolo"
# Default for --images.
IMG_DIR = DATA / "images" / "all"
# Default for --labels.
LBL_DIR = DATA / "labels" / "all"
# Shared rich console for all status output.
console = Console()
| + | ||
class LabelState:
    """Mutable per-image labeling state: the image, its box and any active drag."""

    def __init__(self, image, scale: float):
        # Source image and the factor applied when it is shown on screen.
        self.image, self.scale = image, scale
        # Current box as (x1, y1, x2, y2) in image pixels, or None.
        self.box = None
        # Display-space endpoints of a drag in progress.
        self.drag_start = self.drag_now = None
        # Flag requesting that the box be saved and the next image shown.
        self.auto_advance = False

    def set_box_from_display(self, x1, y1, x2, y2):
        """Store a box given by two display-space corners, in image pixels.

        The corners may come in any order. Coordinates are divided by the
        display scale, truncated to ints and clamped to the image. A result
        narrower or shorter than 2 px is ignored and the previous box kept.
        """

        def to_image(value):
            return int(value / self.scale)

        raw_left, raw_top = to_image(min(x1, x2)), to_image(min(y1, y2))
        raw_right, raw_bottom = to_image(max(x1, x2)), to_image(max(y1, y2))

        height, width = self.image.shape[:2]
        left, top = max(0, raw_left), max(0, raw_top)
        right, bottom = min(width - 1, raw_right), min(height - 1, raw_bottom)

        if right - left < 2 or bottom - top < 2:
            return
        self.box = (left, top, right, bottom)
| + | ||
def yolo_line(box, w, h) -> str:
    """Format a pixel box as a single YOLO label line for class 0.

    Args:
        box: (x1, y1, x2, y2) corners in pixels.
        w: image width in pixels.
        h: image height in pixels.

    Returns:
        "0 xc yc bw bh" followed by a newline, with the center and size
        normalized by the image dimensions and printed to six decimals.
    """
    left, top, right, bottom = box
    center_x = (left + right) / 2.0
    center_y = (top + bottom) / 2.0
    fields = (
        center_x / w,
        center_y / h,
        (right - left) / w,
        (bottom - top) / h,
    )
    return "0 " + " ".join(f"{value:.6f}" for value in fields) + "\n"
| + | ||
def load_label(path: Path, w: int, h: int):
    """Load the first box of a YOLO label file as pixel coordinates.

    Args:
        path: label file holding "class xc yc bw bh" with normalized values.
        w: image width in pixels.
        h: image height in pixels.

    Returns:
        (x1, y1, x2, y2) as ints, or None when the file is missing, blank,
        has fewer than five fields, or has non-numeric coordinates.
    """
    if not path.exists():
        return None
    # A single read covers both the blank-file check and the parse: a blank
    # file splits into an empty list and is rejected by the length test.
    parts = path.read_text(encoding="utf-8").split()
    if len(parts) < 5:
        return None
    try:
        xc, yc, bw, bh = (float(field) for field in parts[1:5])
    except ValueError:
        # Treat an unparsable line like a short one: no usable label.
        return None
    # round() rather than int(): label files keep six decimals, so an edge
    # can come back as e.g. 100.9997 and truncation would move it by a pixel
    # each time the label is loaded and written again.
    x1 = round((xc - bw / 2.0) * w)
    y1 = round((yc - bh / 2.0) * h)
    x2 = round((xc + bw / 2.0) * w)
    y2 = round((yc + bh / 2.0) * h)
    return x1, y1, x2, y2
| + | ||
def draw(state: LabelState, title: str):
    """Build the frame to display for the current labeling state.

    The saved box is drawn on a copy of the image, the copy is resized by
    ``state.scale``, and then the in-progress drag rectangle, the outlined
    title and the help bar are drawn on the resized frame.

    Returns:
        The resized, annotated frame.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    black, white = (0, 0, 0), (255, 255, 255)

    canvas = state.image.copy()
    if state.box:
        left, top, right, bottom = state.box
        cv2.rectangle(canvas, (left, top), (right, bottom), (0, 255, 255), 2)
    disp = cv2.resize(canvas, None, fx=state.scale, fy=state.scale, interpolation=cv2.INTER_AREA)

    # The drag rectangle is already in display coordinates, so it goes on the
    # resized frame rather than on the source image.
    if state.drag_start and state.drag_now:
        cv2.rectangle(disp, state.drag_start, state.drag_now, (0, 255, 0), 2)

    # Thick black pass first, thin white pass on top, for an outlined title.
    for color, thickness in ((black, 4), (white, 2)):
        cv2.putText(disp, title, (12, 28), font, 0.8, color, thickness, cv2.LINE_AA)

    help_text = "Draw tight box around baseball. Mouse-up saves + next. b=back c=clear e=empty q=quit"
    rows, cols = disp.shape[0], disp.shape[1]
    bar_right = min(cols - 8, 980)
    cv2.rectangle(disp, (8, rows - 42), (bar_right, rows - 8), black, -1)
    cv2.putText(disp, help_text, (16, rows - 18), font, 0.65, white, 2, cv2.LINE_AA)
    return disp
| + | ||
def main() -> int:
    """Run the interactive OpenCV labeling loop over a folder of images.

    Start position, in priority order: ``--start-name``; then ``--start-index``
    (1-based, clamped to the image list); then the first image without a
    label file unless ``--from-start`` is given; otherwise the first image.

    Returns:
        0 when the user quits or steps past the last image; 2 when the
        arguments are unusable (non-positive ``--scale``, missing image
        directory, no images, or unknown ``--start-name``).
    """
    ap = argparse.ArgumentParser(description="Label YOLO ball boxes.")
    ap.add_argument("--images", type=Path, default=IMG_DIR)
    ap.add_argument("--labels", type=Path, default=LBL_DIR)
    ap.add_argument("--scale", type=float, default=0.65)
    ap.add_argument("--from-start", action="store_true", help="start at first image instead of first unlabeled image")
    ap.add_argument("--start-index", type=int, default=0, help="1-based index to start from")
    ap.add_argument("--start-name", default="", help="exact image filename to start from")
    ap.add_argument("--manual", action="store_true", help="do not auto-save/advance on mouse release")
    args = ap.parse_args()

    # The scale is both a resize factor and a divisor when mapping mouse
    # coordinates back to the image, so it has to be strictly positive.
    if args.scale <= 0:
        console.print(f"[red]--scale must be greater than 0, got {args.scale}[/red]")
        return 2
    if not args.images.is_dir():
        console.print(f"[red]Image directory not found: {args.images}[/red]")
        return 2

    args.labels.mkdir(parents=True, exist_ok=True)
    suffixes = {".jpg", ".jpeg", ".png"}
    images = sorted(p for p in args.images.iterdir() if p.suffix.lower() in suffixes)
    if not images:
        console.print(f"[red]No images found in {args.images}[/red]")
        return 2

    idx = 0
    if args.start_name:
        found = next((i for i, p in enumerate(images) if p.name == args.start_name), None)
        if found is None:
            console.print(f"[red]--start-name not found: {args.start_name}[/red]")
            return 2
        idx = found
    elif args.start_index > 0:
        idx = min(len(images) - 1, max(0, args.start_index - 1))
    elif not args.from_start:
        # Resume at the first image that has no label file yet; if every
        # image is labeled, idx stays at 0.
        for i, image_path in enumerate(images):
            if not (args.labels / f"{image_path.stem}.txt").exists():
                idx = i
                break
    win = "label ball"

    # try/finally so the OpenCV window is closed on every exit path,
    # including an exception while reading or writing a label.
    try:
        while 0 <= idx < len(images):
            path = images[idx]
            img = cv2.imread(str(path))
            if img is None:
                # Unreadable image: skip forward.
                idx += 1
                continue
            h, w = img.shape[:2]
            label_path = args.labels / f"{path.stem}.txt"
            state = LabelState(img, args.scale)
            state.box = load_label(label_path, w, h)

            def on_mouse(event, x, y, flags, param):
                if event == cv2.EVENT_LBUTTONDOWN:
                    state.drag_start = (x, y)
                    state.drag_now = (x, y)
                elif event == cv2.EVENT_MOUSEMOVE and state.drag_start:
                    state.drag_now = (x, y)
                elif event == cv2.EVENT_LBUTTONUP and state.drag_start:
                    previous = state.box
                    state.set_box_from_display(state.drag_start[0], state.drag_start[1], x, y)
                    state.drag_start = None
                    state.drag_now = None
                    # Auto-save only when this drag actually produced a box.
                    # A rejected drag (stray click) leaves state.box as the
                    # very same object and must not skip past the image.
                    drew_box = state.box is not None and state.box is not previous
                    if drew_box and not args.manual:
                        state.auto_advance = True

            cv2.namedWindow(win, cv2.WINDOW_NORMAL)
            cv2.setMouseCallback(win, on_mouse)
            while True:
                title = f"{idx + 1}/{len(images)} {path.name} auto-save on mouse-up | s save | b back | c clear | q quit"
                cv2.imshow(win, draw(state, title))
                # waitKey also services the window's event queue, which is
                # what delivers the mouse callbacks.
                key = cv2.waitKey(16) & 0xFF
                if state.auto_advance:
                    label_path.write_text(yolo_line(state.box, w, h), encoding="utf-8")
                    idx += 1
                    break
                if key in (ord("q"), 27):
                    return 0
                if key == ord("c"):
                    # Clear removes the on-disk label as well as the box.
                    state.box = None
                    if label_path.exists():
                        label_path.unlink()
                if key in (ord("s"), 13, 32):
                    # Save the current box, if there is one, then move on.
                    if state.box:
                        label_path.write_text(yolo_line(state.box, w, h), encoding="utf-8")
                    idx += 1
                    break
                if key == ord("e"):
                    # A zero-length label file marks the image as "no ball".
                    label_path.write_text("", encoding="utf-8")
                    idx += 1
                    break
                if key == ord("n"):
                    idx += 1
                    break
                if key == ord("b"):
                    idx = max(0, idx - 1)
                    break
    finally:
        cv2.destroyAllWindows()

    console.print("[green]Labeling complete.[/green]")
    return 0
| + | ||
# Script entry point: run main() and hand its return value to SystemExit
# so it becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
| @@ -0,0 +1,94 @@ | ||
| + | """Split labeled YOLO data into train/val/test folders.""" | |
| + | from __future__ import annotations | |
| + | ||
| + | import argparse | |
| + | import random | |
| + | import shutil | |
| + | from pathlib import Path | |
| + | ||
| + | from rich.console import Console | |
| + | ||
# Parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Dataset folder; the train/val/test splits are created underneath it.
DATA = ROOT / "datasets" / "mlb26_ball_yolo"
# Source pool of images and their label files that the split is drawn from.
IMG_ALL = DATA / "images" / "all"
LBL_ALL = DATA / "labels" / "all"
# Shared rich console for all status output.
console = Console()
| + | ||
def reset_split_dirs() -> None:
    """Recreate empty train/val/test folders under both images/ and labels/.

    Any existing split folder is removed together with its contents first.
    """
    targets = [
        DATA / kind / split
        for kind in ("images", "labels")
        for split in ("train", "val", "test")
    ]
    for target in targets:
        if target.exists():
            shutil.rmtree(target)
        target.mkdir(parents=True, exist_ok=True)
| + | ||
def main() -> int:
    """Split the labeled image pool into train/val/test folders.

    Images with a non-empty label file count as positives, images with a
    zero-length label file as negatives, and images without a label file are
    left out. Positives and negatives are shuffled with ``--seed`` and split
    separately with the same fractions, so each split keeps both kinds.

    Returns:
        0 on success; 2 for invalid fractions/ratio or a missing source
        folder; 3 when no positive image/label pair exists.
    """
    ap = argparse.ArgumentParser(description="Split YOLO ball dataset.")
    ap.add_argument("--val", type=float, default=0.18, help="Validation fraction.")
    ap.add_argument("--test", type=float, default=0.05, help="Test fraction.")
    ap.add_argument("--seed", type=int, default=26)
    ap.add_argument(
        "--max-neg-ratio",
        type=float,
        default=4.0,
        help="Maximum empty/no-ball images per positive image. Use 0 to keep all negatives.",
    )
    args = ap.parse_args()

    # Validate before anything is deleted: negative fractions would turn the
    # slices below into overlapping ranges, and fractions summing to 1 or
    # more would leave no training data.
    if args.val < 0 or args.test < 0 or args.val + args.test >= 1:
        console.print("[red]--val and --test must be >= 0 and sum to less than 1.[/red]")
        return 2
    if args.max_neg_ratio < 0:
        console.print("[red]--max-neg-ratio must be >= 0.[/red]")
        return 2

    if not IMG_ALL.exists() or not LBL_ALL.exists():
        console.print(f"[red]Expected images at {IMG_ALL} and labels at {LBL_ALL}.[/red]")
        return 2

    suffixes = {".jpg", ".jpeg", ".png"}
    images = sorted(p for p in IMG_ALL.iterdir() if p.suffix.lower() in suffixes)
    positives = []
    negatives = []
    for img in images:
        label = LBL_ALL / f"{img.stem}.txt"
        if not label.exists():
            continue
        # A zero-length label file marks the image as containing no ball.
        if label.stat().st_size > 0:
            positives.append((img, label))
        else:
            negatives.append((img, label))

    if not positives:
        console.print("[red]No labeled image/label pairs found.[/red]")
        return 3

    rng = random.Random(args.seed)
    rng.shuffle(positives)
    rng.shuffle(negatives)
    if args.max_neg_ratio and negatives:
        max_negs = int(round(len(positives) * args.max_neg_ratio))
        negatives = negatives[:max_negs]

    def split_pairs(pairs):
        """Cut an already shuffled list into test, val and train slices."""
        n = len(pairs)
        n_test = int(round(n * args.test))
        n_val = int(round(n * args.val))
        return {
            "test": pairs[:n_test],
            "val": pairs[n_test:n_test + n_val],
            "train": pairs[n_test + n_val:],
        }

    pos_buckets = split_pairs(positives)
    neg_buckets = split_pairs(negatives)

    reset_split_dirs()
    for split in ("test", "val", "train"):
        pos_items = pos_buckets[split]
        neg_items = neg_buckets[split]
        for img, label in pos_items + neg_items:
            shutil.copy2(img, DATA / "images" / split / img.name)
            shutil.copy2(label, DATA / "labels" / split / label.name)
        # Counts come straight from the two buckets, so labels need not be
        # stat()ed a second time.
        total = len(pos_items) + len(neg_items)
        console.print(f"[green]{split}: {total} ({len(pos_items)} ball, {len(neg_items)} empty)[/green]")

    console.print(f"[bold green]Dataset ready at {DATA}[/bold green]")
    return 0
| + | ||
# Script entry point: run main() and hand its return value to SystemExit
# so it becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
| @@ -0,0 +1,68 @@ | ||
| + | """Train and export a one-class MLB ball detector with Ultralytics YOLO.""" | |
| + | from __future__ import annotations | |
| + | ||
| + | import argparse | |
| + | import shutil | |
| + | from pathlib import Path | |
| + | ||
| + | from rich.console import Console | |
| + | ||
# Parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Dataset description passed to training as the "data" argument.
DATA_YAML = ROOT / "configs" / "ball_yolo.yaml"
# Destination folder for the exported ONNX and PyTorch weights.
MODEL_DIR = ROOT / "configs" / "models"
# Shared rich console for all status output.
console = Console()
| + | ||
def main() -> int:
    """Train the one-class ball detector, export it to ONNX and publish both files.

    After training, ``best.pt`` from the run's ``weights`` folder is exported
    to ONNX, and the ONNX file and the weights are copied into ``MODEL_DIR``
    as ``mlb26_ball.onnx`` and ``mlb26_ball.pt``.

    Returns:
        0 on success, 2 when training finished without producing ``best.pt``.
    """
    parser = argparse.ArgumentParser(description="Train/export MLB ball YOLO model.")
    parser.add_argument("--model", default="yolo11n.pt", help="Base model, e.g. yolo11n.pt or yolov8n.pt.")
    parser.add_argument("--epochs", type=int, default=80)
    parser.add_argument("--imgsz", type=int, default=640)
    parser.add_argument("--batch", type=float, default=None, help="Batch size. Omit for Ultralytics default.")
    parser.add_argument("--device", default=None, help="cuda device id, cpu, or omit for auto.")
    parser.add_argument("--name", default="mlb26_ball")
    parser.add_argument("--workers", type=int, default=0, help="Dataloader workers. Use 0 on Windows/network paths.")
    opts = parser.parse_args()

    from ultralytics import YOLO

    model = YOLO(opts.model)

    # Fixed training configuration; only the entries taken from `opts` vary.
    overrides = dict(
        data=str(DATA_YAML),
        epochs=opts.epochs,
        imgsz=opts.imgsz,
        name=opts.name,
        project=str(ROOT / "runs" / "detect"),
        single_cls=True,
        patience=25,
        amp=False,
        workers=opts.workers,
        mosaic=0.0,
        fliplr=0.0,
        erasing=0.0,
        hsv_h=0.0,
        hsv_s=0.15,
        hsv_v=0.15,
    )
    batch = opts.batch
    if batch is not None:
        # Values of 1 or more are passed as an int; anything below 1 is
        # forwarded unchanged as a float.
        if batch >= 1:
            overrides["batch"] = int(batch)
        else:
            overrides["batch"] = batch
    if opts.device:
        overrides["device"] = opts.device

    run = model.train(**overrides)
    best = Path(run.save_dir) / "weights" / "best.pt"
    if not best.exists():
        console.print(f"[red]Training finished but best.pt was not found at {best}[/red]")
        return 2

    exported = YOLO(str(best)).export(format="onnx", imgsz=opts.imgsz, simplify=True, opset=12)
    onnx_path = Path(exported)

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    final_onnx = MODEL_DIR / "mlb26_ball.onnx"
    final_pt = MODEL_DIR / "mlb26_ball.pt"
    shutil.copy2(onnx_path, final_onnx)
    shutil.copy2(best, final_pt)
    console.print(f"[bold green]Exported ONNX -> {final_onnx}[/bold green]")
    console.print(f"[green]Saved PyTorch weights -> {final_pt}[/green]")
    return 0
| + | ||
# Script entry point: run main() and hand its return value to SystemExit
# so it becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())