| 1 | """YOLO/ONNX GCVWorker for MLB The Show 26. |
| 2 | |
| 3 | Use this after training/exporting configs/models/mlb26_ball.onnx. It keeps the |
| 4 | same 20-byte GCV packet contract as mlb26_gcv.py and mlb26_bridge.gpc. |
| 5 | """ |
| 6 | from __future__ import annotations |
| 7 | |
| 8 | import os |
| 9 | import struct |
| 10 | import sys |
| 11 | import time |
| 12 | import ctypes |
| 13 | from pathlib import Path |
| 14 | |
| 15 | import cv2 |
| 16 | import numpy as np |
| 17 | |
| 18 | def _env_roots() -> tuple[Path, ...]: |
| 19 | extra = os.environ.get("PITCH_TRACKER_CV_ROOT") |
| 20 | if not extra: |
| 21 | return () |
| 22 | return tuple(Path(p.strip()) for p in extra.split(os.pathsep) if p.strip()) |
| 23 | |
| 24 | |
# Roots searched (in order) for configs/models, runtime_dlls and bundled venvs:
# the parent of this script's directory first, then PITCH_TRACKER_CV_ROOT extras.
PROJECT_ROOTS = (Path(__file__).resolve().parents[1], *_env_roots())
PREFERRED_ROOTS = PROJECT_ROOTS

# (file name, square export input size) pairs; the first existing file wins,
# so smaller inputs are tried first.
MODEL_NAMES = (
    ("mlb26_ball_320.onnx", 320),
    ("mlb26_ball_416.onnx", 416),
    ("mlb26_ball_512.onnx", 512),
    ("mlb26_ball.onnx", 640),
)
# The OpenCV DNN fallback currently searches the very same list.
CV2_MODEL_NAMES = MODEL_NAMES

# Detector settings: default input size (replaced by the size of the model
# actually found), minimum confidence, and NMS IoU threshold.
YOLO_INPUT = 320
YOLO_CONF = 0.18
YOLO_NMS = 0.45

# Frame fractions: DETECT_* is the crop fed to the model; ACQUIRE_* is where a
# new ball may be picked up (with a maximum diameter as a fraction of width).
DETECT_X_FRAC = (0.31, 0.66)
DETECT_Y_FRAC = (0.18, 0.84)
ACQUIRE_X_FRAC = (0.40, 0.60)
ACQUIRE_Y_FRAC = (0.20, 0.52)
ACQUIRE_MAX_DIAM_FRAC = 0.0120
# Lock following: max jump per update (fraction of width), max lock age, and the
# minimum vertical step (negative allows a small upward move).
MAX_LOCK_STEP_FRAC = 0.090
MAX_LOCK_AGE_MS = 180
MIN_DOWNWARD_STEP_PX = -10.0

# Strike zone centre and size as fractions of frame width/height.
SZ_CX_FRAC = 0.492
SZ_CY_FRAC = 0.518
SZ_W_FRAC = 0.203
SZ_H_FRAC = 0.555
# Swing decision tuning.
SWING_PAD_PX = 0
SWING_MIN_DIAM_FRAC = 0.004
SWING_HOLD_MS = 280
SWING_COOLDOWN_MS = 500
SWING_HIT_DIAM_FRAC = 0.0070
SWING_MIN_GROWTH_PX_S = 3.0
SWING_ETA_FIRE_MS = 180
MIN_TRACK_FRAMES_FOR_SWING = 1
PITCH_END_TIMEOUT_MS = 260

# Aim stick: proportional gains on the pixel offset from the zone centre,
# a deadzone, and the output clamp.
AIM_GAIN_X = 0.19
AIM_GAIN_Y = 0.16
AIM_DEADZONE_PX = 6.0
AIM_MAX_STICK = 92.0
# How long a sighting may be extrapolated, and the inference stride while tracking.
BALL_MEMORY_MS = 130
BALL_SWING_MEMORY_MS = 110
INFER_EVERY_N_FRAMES = 2

# When True, aim_x is forced to 25.0 for testing the output path.
DEBUG_FORCE_AIM = False

# CUDA/cuDNN DLLs preloaded (in this order) from the runtime DLL directories
# before onnxruntime is imported.
CUDA_DLL_NAMES = (
    "zlibwapi.dll",
    "cudart64_12.dll",
    "cublas64_12.dll",
    "cublasLt64_12.dll",
    "cufft64_11.dll",
    "cufftw64_11.dll",
    "curand64_10.dll",
    "cusparse64_12.dll",
    "cusolver64_11.dll",
    "cusolverMg64_11.dll",
    "nvJitLink_120_0.dll",
    "nvrtc64_120_0.dll",
    "nvrtc-builtins64_126.dll",
    "cudnn64_9.dll",
    "cudnn_ops64_9.dll",
    "cudnn_adv64_9.dll",
    "cudnn_cnn64_9.dll",
    "cudnn_graph64_9.dll",
    "cudnn_heuristic64_9.dll",
    "cudnn_engines_runtime_compiled64_9.dll",
    "cudnn_engines_precompiled64_9.dll",
)

# Additional onnxruntime DLLs to preload (none at present).
ORT_DLL_NAMES = ()
| 102 | |
| 103 | def _fix32(value: float) -> bytes: |
| 104 | v = int(round(value * 65536.0)) |
| 105 | v = max(-2**31, min(2**31 - 1, v)) |
| 106 | return struct.pack(">i", v) |
| 107 | |
| 108 | def _int16(value: int) -> bytes: |
| 109 | v = max(-32768, min(32767, int(value))) |
| 110 | return struct.pack(">h", v) |
| 111 | |
class GCVWorker:
    """Detect the pitched ball with an ONNX YOLO model and emit a GCV packet.

    ``process(frame)`` runs the detector on a crop of the frame, tracks one
    ball across frames, decides when to press contact, draws a debug overlay
    onto ``frame`` in place and returns ``(frame, gcvdata)``. ``gcvdata`` is
    the 20-byte packet:

    - bytes 0-3:   aim_x, big-endian signed 16.16 fixed point
    - bytes 4-7:   aim_y, big-endian signed 16.16 fixed point
    - bytes 8-9:   armed (always 1), big-endian int16
    - bytes 10-11: in_flight (1 while a real or extrapolated ball exists)
    - bytes 12-13: press_contact
    - bytes 14-15: press_power (always 0 here)
    - bytes 16-17: eta_ms (predicted ms until the hit size; 0 = none)
    - bytes 18-19: debug_flags (2 = real detection, 4 = prediction, 8 = pitch active)
    """

    def __init__(self, width, height):
        # Keep running from this script's directory as before; abspath() stops
        # chdir("") from failing when __file__ carries no directory component.
        os.chdir(os.path.dirname(os.path.abspath(__file__)))
        self.width = width
        self.height = height
        self.gcvdata = bytearray(20)  # outgoing packet, rewritten every frame
        # Inference backend: an onnxruntime session or an OpenCV DNN net.
        self.net = None
        self.session = None
        self.input_name = None
        self.backend_name = "none"
        self.model_path = None
        self.yolo_input = YOLO_INPUT
        # Diagnostics shown on the overlay / written by _write_debug().
        self.model_error = ""
        self.ort_error = ""
        self.runtime_site = ""
        self.runtime_dll_dirs = []
        self.preload_error = ""
        # Swing and tracking state; timestamps are time.time_ns() values and
        # balls are (cx, cy, w, h, conf, ts_ns) tuples in full-frame pixels.
        self.last_swing_ts = 0
        self.swing_hold_until = 0
        self.last_ball = None
        self.prev_ball = None
        self.locked_ball = None
        self.last_ball_vel = (0.0, 0.0)  # px/s
        self.pitch_active = False
        self.pitch_swinged = False
        self.track_frames = 0
        self.pitch_last_seen_ns = 0
        self.frame_id = 0
        self._load_model()

    def _write_debug(self, message: str) -> None:
        """Write a runtime diagnostics file into every existing project root."""
        lines = [
            f"time={time.strftime('%Y-%m-%d %H:%M:%S')}",
            f"message={message}",
            f"file={__file__}",
            f"cwd={os.getcwd()}",
            f"python={sys.version}",
            f"executable={sys.executable}",
            f"model_path={self.model_path}",
            f"backend={self.backend_name}",
            f"yolo_input={self.yolo_input}",
            f"runtime_site={self.runtime_site}",
            f"runtime_dll_dirs={self.runtime_dll_dirs}",
            f"preload_error={self.preload_error}",
            f"ort_error={self.ort_error}",
        ]
        for root in PROJECT_ROOTS:
            try:
                if root.exists():
                    (root / "gcv_yolo_runtime_debug.txt").write_text("\n".join(lines), encoding="utf-8")
            except Exception:
                # Deliberately best-effort: diagnostics must never break model loading.
                pass

    @staticmethod
    def _runtime_site_candidates():
        """Yield ``.venv_yolo`` then ``.venv`` site-packages dirs for each distinct root.

        Paths use the Windows venv layout (``Lib/site-packages``).
        """
        seen = set()
        for root in PREFERRED_ROOTS:
            if str(root) in seen:
                continue
            seen.add(str(root))
            yield root / ".venv_yolo" / "Lib" / "site-packages"
            yield root / ".venv" / "Lib" / "site-packages"

    @staticmethod
    def _add_dll_dirs(dll_dirs):
        """Prepend existing dirs to PATH and register them with ``os.add_dll_directory``.

        ``None`` entries and missing directories are skipped. Returns the existing,
        de-duplicated directories in their original order.
        """
        existing = []
        for dll_dir in dll_dirs:
            if dll_dir and dll_dir.exists() and dll_dir not in existing:
                existing.append(dll_dir)
        if existing:
            os.environ["PATH"] = os.pathsep.join(str(d) for d in existing) + os.pathsep + os.environ.get("PATH", "")
        if hasattr(os, "add_dll_directory"):
            for dll_dir in existing:
                try:
                    os.add_dll_directory(str(dll_dir))
                except Exception:
                    # PATH above is the fallback search mechanism.
                    pass
        return existing

    @staticmethod
    def _preload_dlls(dll_dirs, names):
        """Load each DLL in ``names`` from the first directory that contains it.

        Returns up to six ``"name:error"`` entries joined by ``"; "`` (empty string
        when every found DLL loaded).
        """
        errors = []
        for name in names:
            for dll_dir in dll_dirs:
                dll_path = dll_dir / name
                if not dll_path.exists():
                    continue
                try:
                    ctypes.WinDLL(str(dll_path))
                except Exception as exc:
                    errors.append(f"{name}:{str(exc)[:70]}")
                break
        return "; ".join(errors[:6])

    def __del__(self):
        try:
            del self.gcvdata
        except Exception:
            pass

    @staticmethod
    def _find_model(roots, names):
        """Return ``(path, input_size)`` of the first existing model file.

        Roots are searched in order and, inside each root's ``configs/models``,
        ``names`` (``(file name, input size)`` pairs) in order. Returns
        ``(None, None)`` when nothing exists.
        """
        for root in roots:
            model_dir = root / "configs" / "models"
            for model_name, input_size in names:
                candidate = model_dir / model_name
                if candidate.exists():
                    return candidate, input_size
        return None, None

    def _prepare_runtime_paths(self) -> None:
        """Expose a bundled onnxruntime site-packages and its DLL dirs, and preload DLLs."""
        yolo_site = next(
            (c for c in self._runtime_site_candidates() if (c / "onnxruntime").exists()),
            None,
        )
        runtime_dlls = next(
            (root / "runtime_dlls" for root in PREFERRED_ROOTS if (root / "runtime_dlls").exists()),
            None,
        )
        # "Not found" must be None, not Path(): Path() is "." and always exists,
        # which would put the current directory on PATH and the DLL search path.
        yolo_capi = yolo_site / "onnxruntime" / "capi" if yolo_site is not None else None
        torch_lib = yolo_site / "torch" / "lib" if yolo_site is not None else None
        if yolo_site is not None and str(yolo_site) not in sys.path:
            sys.path.insert(0, str(yolo_site))
        self.runtime_site = str(yolo_site) if yolo_site is not None else ""
        dll_dirs = self._add_dll_dirs((runtime_dlls, torch_lib, yolo_capi))
        self.runtime_dll_dirs = [str(d) for d in dll_dirs]
        preload_errors = [
            err
            for err in (
                self._preload_dlls(dll_dirs, CUDA_DLL_NAMES),
                self._preload_dlls(dll_dirs, ORT_DLL_NAMES),
            )
            if err
        ]
        self.preload_error = " | ".join(preload_errors)[:240]

    def _load_onnxruntime(self) -> bool:
        """Create an onnxruntime session (CUDA, else CPU provider).

        Returns True on success. On failure records ``ort_error`` and returns False.
        """
        self._prepare_runtime_paths()
        try:
            import onnxruntime as ort

            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
            self.session = ort.InferenceSession(str(self.model_path), providers=providers)
            self.input_name = self.session.get_inputs()[0].name
            active = self.session.get_providers()
            self.backend_name = "ort" if "CUDAExecutionProvider" in active else "ortcpu"
            self.model_error = ""
            self._write_debug(f"loaded onnxruntime providers={active}")
            return True
        except Exception as exc:
            self.ort_error = str(exc)[:120]
            self.session = None
            self.input_name = None
            return False

    def _load_cv2(self) -> None:
        """Load the model with OpenCV DNN on the CPU (fallback when onnxruntime fails)."""
        cv2_path, cv2_size = self._find_model(PREFERRED_ROOTS, CV2_MODEL_NAMES)
        if cv2_path is not None:
            self.model_path = cv2_path
            self.yolo_input = cv2_size
        self.net = cv2.dnn.readNetFromONNX(str(self.model_path))
        self.net.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
        self.net.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)
        self.backend_name = "cv2"
        self.model_error = ""
        self._write_debug("loaded cv2 fallback")

    def _load_model(self) -> None:
        """Locate the ball model and create an inference backend.

        onnxruntime is tried first, then OpenCV DNN. On failure ``net`` and
        ``session`` stay ``None`` and ``model_error`` describes the problem.
        """
        try:
            model_path, input_size = self._find_model(PREFERRED_ROOTS, MODEL_NAMES)
            if model_path is None:
                self.model_error = "missing mlb26 ball ONNX"
                return
            self.model_path = model_path
            self.yolo_input = input_size
            if self._load_onnxruntime():
                return
            self._load_cv2()
        except Exception as exc:
            self.net = None
            self.session = None
            self.model_error = str(exc)[:80]

    @staticmethod
    def _rows_from_output(output):
        """Normalise a raw model output to a 2-D ``(num_rows, num_cols)`` array.

        Takes the first output and drops a leading batch axis. Transposes when the
        first axis is small (<= 85) and shorter than the second, i.e. attribute-major
        output. Returns an empty ``(0, 0)`` array for any other rank.
        """
        out = output[0] if isinstance(output, (tuple, list)) else output
        out = np.asarray(out)
        if out.ndim == 3:
            out = out[0]
        if out.ndim != 2:
            return np.empty((0, 0), dtype=np.float32)
        if out.shape[0] <= 85 and out.shape[0] < out.shape[1]:
            out = out.T
        return out

    @staticmethod
    def _row_scores(rows):
        """Return one confidence per row of a ``(N, C)`` detection array, C >= 5.

        - 5 columns: column 4.
        - 6 columns: column 4 when column 5 is <= 1.0, otherwise column 5.
        - More columns: the maximum over columns 4 onward.
        """
        ncols = rows.shape[1]
        if ncols == 5:
            return rows[:, 4]
        if ncols == 6:
            return np.where(rows[:, 5] <= 1.0, rows[:, 4], rows[:, 5])
        return rows[:, 4:].max(axis=1)

    def _detect_candidates(self, frame):
        """Run the detector on the DETECT_* crop and return ball candidates.

        Returns a list of ``(cx, cy, w, h, conf)`` tuples in full-frame pixels.
        Candidates are filtered by confidence, aspect ratio and size, then
        reduced by NMS. Returns ``[]`` when no model is loaded.
        """
        if self.net is None and self.session is None:
            return []

        h, w = frame.shape[:2]
        x0 = max(0, int(w * DETECT_X_FRAC[0]))
        x1 = min(w, int(w * DETECT_X_FRAC[1]))
        y0 = max(0, int(h * DETECT_Y_FRAC[0]))
        y1 = min(h, int(h * DETECT_Y_FRAC[1]))
        if x1 <= x0 or y1 <= y0:
            return []

        roi = frame[y0:y1, x0:x1]
        roi_h, roi_w = roi.shape[:2]
        # The ROI is stretched (crop=False) to a square blob, so x and y are
        # rescaled independently below.
        blob = cv2.dnn.blobFromImage(
            roi,
            scalefactor=1.0 / 255.0,
            size=(self.yolo_input, self.yolo_input),
            mean=(0, 0, 0),
            swapRB=True,
            crop=False,
        )
        if self.session is not None:
            output = self.session.run(None, {self.input_name: blob})
        else:
            self.net.setInput(blob)
            output = self.net.forward()
        rows = self._rows_from_output(output)
        if rows.shape[1] < 5:
            return []

        # Score all anchors at once; a per-row Python loop over thousands of rows
        # dominated the frame budget. Non-finite rows are dropped because NaN
        # slips past the confidence test and int(NaN) below would raise.
        row_scores = self._row_scores(rows)
        valid = (
            np.isfinite(row_scores)
            & (row_scores >= YOLO_CONF)
            & np.isfinite(rows[:, :4]).all(axis=1)
        )

        boxes = []
        scores = []
        for row, score in zip(rows[valid], row_scores[valid]):
            cx, cy, bw, bh = (float(v) for v in row[:4])
            aspect = bw / max(1.0, bh)
            if aspect < 0.42 or aspect > 2.35:
                continue
            left = (cx - bw / 2.0) * roi_w / self.yolo_input
            top = (cy - bh / 2.0) * roi_h / self.yolo_input
            ww = bw * roi_w / self.yolo_input
            hh = bh * roi_h / self.yolo_input
            if ww < 3 or hh < 3:
                continue
            if ww > roi_w * 0.16 or hh > roi_h * 0.16:
                continue
            boxes.append([int(left), int(top), int(ww), int(hh)])
            scores.append(float(score))

        if not boxes:
            return []
        picked = cv2.dnn.NMSBoxes(boxes, scores, YOLO_CONF, YOLO_NMS)
        if len(picked) == 0:
            return []
        candidates = []
        for idx in np.array(picked).reshape(-1):
            bx, by, bw, bh = boxes[int(idx)]
            candidates.append((
                float(x0 + bx + bw / 2.0),
                float(y0 + by + bh / 2.0),
                float(bw),
                float(bh),
                float(scores[int(idx)]),
            ))
        return candidates

    def _choose_ball(self, candidates, frame_w, frame_h, ts_ns):
        """Pick the candidate to track this frame and update ``locked_ball``.

        Selection order:

        1. While the lock is at most MAX_LOCK_AGE_MS old, keep following it. Among
           candidates within ``MAX_LOCK_STEP_FRAC * frame_w``, with
           ``dy >= MIN_DOWNWARD_STEP_PX`` and a size ratio in [0.35, 3.25], pick
           the lowest ``distance - 35 * conf``.
        2. Otherwise take the most confident candidate inside the ACQUIRE_* window.
        3. Failing that, take the most confident small candidate not far below
           that window.

        Returns the chosen ``(cx, cy, w, h, conf)`` or ``None``.
        """
        if not candidates:
            # Drop the lock once the last sighting is too old to follow.
            if self.last_ball and (ts_ns - self.last_ball[5]) > MAX_LOCK_AGE_MS * 1_000_000:
                self.locked_ball = None
            return None

        max_step = frame_w * MAX_LOCK_STEP_FRAC
        if self.locked_ball is not None:
            lx, ly, lw, lh, _lconf, lts = self.locked_ball
            if (ts_ns - lts) / 1_000_000 <= MAX_LOCK_AGE_MS:
                scored = []
                for cand in candidates:
                    cx, cy, cw, ch, conf = cand
                    dx = cx - lx
                    dy = cy - ly
                    dist = (dx * dx + dy * dy) ** 0.5
                    size_ratio = max(cw, ch) / max(1.0, max(lw, lh))
                    if dist <= max_step and dy >= MIN_DOWNWARD_STEP_PX and 0.35 <= size_ratio <= 3.25:
                        scored.append((dist - conf * 35.0, cand))
                if scored:
                    chosen = min(scored, key=lambda item: item[0])[1]
                    self.locked_ball = (*chosen, ts_ns)
                    return chosen

        acquire_left = frame_w * ACQUIRE_X_FRAC[0]
        acquire_right = frame_w * ACQUIRE_X_FRAC[1]
        acquire_top = frame_h * ACQUIRE_Y_FRAC[0]
        acquire_bottom = frame_h * ACQUIRE_Y_FRAC[1]
        pool = [
            c for c in candidates
            if acquire_left <= c[0] <= acquire_right and acquire_top <= c[1] <= acquire_bottom
        ]
        if not pool:
            pool = [
                c for c in candidates
                if c[1] <= frame_h * (ACQUIRE_Y_FRAC[1] + 0.10) and max(c[2], c[3]) <= frame_w * 0.020
            ]
            if not pool:
                return None
        chosen = max(pool, key=lambda c: c[4])
        self.locked_ball = (*chosen, ts_ns)
        return chosen

    @staticmethod
    def _segment_hits_box(prev, cur, left, right, top, bottom):
        """Return True if the downward segment prev->cur touches the box.

        Returns False when either point is missing or ``cur`` is not below ``prev``.
        Otherwise checks both endpoints, then where the segment crosses the top
        and bottom edges.
        """
        if prev is None or cur is None:
            return False
        px, py = prev[:2]
        cx, cy = cur[:2]
        if cy <= py:
            return False
        if left <= px <= right and top <= py <= bottom:
            return True
        if left <= cx <= right and top <= cy <= bottom:
            return True
        if cy < top or py > bottom:
            return False
        for y in (top, bottom):
            if py <= y <= cy:
                t = (y - py) / max(1.0, cy - py)
                x = px + (cx - px) * t
                if left <= x <= right:
                    return True
        return False

    @staticmethod
    def _ball_in_zone(ball, sz_left, sz_right, sz_top, sz_bottom):
        """Return True when the ball centre lies inside the (inclusive) rectangle."""
        bx, by = ball[0], ball[1]
        return sz_left <= bx <= sz_right and sz_top <= by <= sz_bottom

    def _compute_swing_ready(self, prev_ball, cur_ball, sz_left, sz_right, sz_top, sz_bottom, full_w):
        """Decide whether to swing at ``cur_ball``.

        A swing is considered only when the ball is at least the minimum tracked
        size and its centre is in the (padded) zone. It is ready when:

        - the ball has reached the hit diameter, or
        - its growth rate since ``prev_ball`` predicts reaching that diameter
          within SWING_ETA_FIRE_MS.

        Returns ``(ready, eta_ms, pred_good)``; ``eta_ms`` is 0 when there is no
        prediction.
        """
        ball_diam = max(cur_ball[2], cur_ball[3])
        min_track_diam = max(6.0, full_w * SWING_MIN_DIAM_FRAC)
        hit_diam = max(min_track_diam + 2.0, full_w * SWING_HIT_DIAM_FRAC)
        if ball_diam < min_track_diam:
            return False, 0, 0
        pad = SWING_PAD_PX
        if not self._ball_in_zone(cur_ball, sz_left - pad, sz_right + pad, sz_top - pad, sz_bottom + pad):
            return False, 0, 0

        close_by_size = ball_diam >= hit_diam
        eta_ms = 0
        pred_good = 0
        if prev_ball is not None:
            prev_diam = max(prev_ball[2], prev_ball[3])
            dt_s = max(1e-6, (cur_ball[5] - prev_ball[5]) / 1e9)
            growth_px_s = (ball_diam - prev_diam) / dt_s
            if growth_px_s >= SWING_MIN_GROWTH_PX_S and ball_diam < hit_diam:
                eta = (hit_diam - ball_diam) * 1000.0 / max(growth_px_s, 1.0)
                if 0.0 <= eta <= 32767.0:
                    # Round sub-millisecond ETAs up to 1: eta_ms == 0 means "no
                    # prediction" and would stop an imminent ball counting as close.
                    eta_ms = max(1, int(eta))
                    pred_good = 1
        close_soon = 0 < eta_ms <= SWING_ETA_FIRE_MS
        if close_by_size:
            pred_good = 1
        return close_by_size or close_soon, eta_ms, pred_good

    @staticmethod
    def _in_acquire_window(ball, full_w, full_h):
        """Return True if the ball is inside the ACQUIRE_* window and small enough."""
        bx, by, bw, bh = ball[0], ball[1], ball[2], ball[3]
        max_d = full_w * ACQUIRE_MAX_DIAM_FRAC
        return (
            full_w * ACQUIRE_X_FRAC[0] <= bx <= full_w * ACQUIRE_X_FRAC[1]
            and full_h * ACQUIRE_Y_FRAC[0] <= by <= full_h * ACQUIRE_Y_FRAC[1]
            and max(bw, bh) <= max_d
        )

    @staticmethod
    def _aim_stick(dx, dy):
        """Map a pixel offset from the zone centre to clamped ``(aim_x, aim_y)``.

        Each axis is ``gain * offset``. It is 0.0 inside AIM_DEADZONE_PX and
        clamped to +/-AIM_MAX_STICK.
        """
        def axis(offset, gain):
            if abs(offset) <= AIM_DEADZONE_PX:
                return 0.0
            return max(-AIM_MAX_STICK, min(AIM_MAX_STICK, gain * offset))

        return axis(dx, AIM_GAIN_X), axis(dy, AIM_GAIN_Y)

    def _draw_overlay(self, frame, zone, press_contact, ready, eval_ball, candidates, is_real_ball):
        """Annotate ``frame`` in place with zone, detect crop, ball marker and status.

        ``zone`` is the strike zone as ``(left, top, right, bottom)`` pixels.
        """
        full_h, full_w = frame.shape[:2]
        sz_left, sz_top, sz_right, sz_bottom = zone
        if press_contact:
            box_color = (0, 255, 255)
        elif ready:
            box_color = (0, 255, 0)
        else:
            box_color = (190, 190, 190)
        cv2.rectangle(frame, (int(sz_left), int(sz_top)), (int(sz_right), int(sz_bottom)), box_color, 6)
        dx0 = int(full_w * DETECT_X_FRAC[0])
        dx1 = int(full_w * DETECT_X_FRAC[1])
        dy0 = int(full_h * DETECT_Y_FRAC[0])
        dy1 = int(full_h * DETECT_Y_FRAC[1])
        cv2.rectangle(frame, (dx0, dy0), (dx1, dy1), (80, 80, 80), 1)
        if eval_ball is not None:
            bx, by, bw, bh, conf = eval_ball[:5]
            cv2.circle(frame, (int(bx), int(by)), max(8, int(max(bw, bh) / 2)), (0, 255, 255), 2)
            cv2.putText(frame, f"{conf:.2f}", (int(bx) + 10, int(by) - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
        elif candidates:
            # Detections exist but none was chosen as the ball.
            for bx, by, _bw, _bh, _conf in candidates[:3]:
                cv2.circle(frame, (int(bx), int(by)), 5, (0, 128, 255), 1)

        has_model = self.net is not None or self.session is not None
        # "DETECT" = real detection this frame, "TRACK" = extrapolated from memory.
        # (The old ordering tested eval_ball first, which made DETECT unreachable.)
        if not has_model:
            status = "NO MODEL"
        elif press_contact:
            status = "SWING"
        elif ready:
            status = "READY"
        elif is_real_ball:
            status = "DETECT"
        elif eval_ball is not None:
            status = "TRACK"
        else:
            status = "WAIT"
        cv2.rectangle(frame, (16, full_h - 92), (500, full_h - 18), (0, 0, 0), -1)
        cv2.putText(frame, f"YOLO v2 {self.backend_name}{self.yolo_input} {status}", (30, full_h - 42),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.35, box_color, 4, cv2.LINE_AA)
        if not has_model and self.model_error:
            cv2.putText(frame, self.model_error, (30, 44),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2, cv2.LINE_AA)
        elif self.backend_name == "cv2" and self.ort_error:
            cv2.putText(frame, f"ORT fail: {self.ort_error[:90]}", (30, 44),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.58, (0, 128, 255), 2, cv2.LINE_AA)

    def process(self, frame):
        """Process one frame: detect/track the ball, decide on a swing, draw the overlay.

        Returns ``(frame, gcvdata)``; ``frame`` is annotated in place and
        ``gcvdata`` is this worker's reused 20-byte packet buffer.
        """
        ts_ns = time.time_ns()
        full_h, full_w = frame.shape[:2]
        self.frame_id += 1
        # Infer every frame while no fresh ball is tracked, so a new pitch is
        # picked up quickly. While tracking, infer only every Nth frame and
        # extrapolate in between. last_ball is never reset to None, so staleness
        # (not None-ness) is what identifies the "searching" state.
        has_fresh_ball = (
            self.last_ball is not None
            and (ts_ns - self.last_ball[5]) <= BALL_MEMORY_MS * 1_000_000
        )
        do_infer = (not has_fresh_ball) or (self.frame_id % max(1, INFER_EVERY_N_FRAMES) == 0)
        if do_infer:
            candidates = self._detect_candidates(frame)
            ball = self._choose_ball(candidates, full_w, full_h, ts_ns)
        else:
            candidates = []
            ball = None

        sz_cx = full_w * SZ_CX_FRAC
        sz_cy = full_h * SZ_CY_FRAC
        sz_w = full_w * SZ_W_FRAC
        sz_h = full_h * SZ_H_FRAC
        sz_left = sz_cx - sz_w / 2
        sz_right = sz_cx + sz_w / 2
        sz_top = sz_cy - sz_h / 2
        sz_bottom = sz_cy + sz_h / 2

        aim_x = 0.0
        aim_y = 0.0
        press_contact = 0
        press_power = 0
        eta_ms = 0
        pred_good = 0
        ready = False
        eval_ball = None
        is_real_ball = ball is not None

        if ball is not None:
            bx, by, bw, bh, conf = ball
            cur_ball = (bx, by, bw, bh, conf, ts_ns)
            if self.last_ball is not None:
                px, py, _pw, _ph, _pconf, pts_ns = self.last_ball
                dt = (ts_ns - pts_ns) / 1e9
                if 0.0 < dt <= BALL_MEMORY_MS / 1000.0:
                    self.last_ball_vel = ((bx - px) / dt, (by - py) / dt)
                elif dt > BALL_MEMORY_MS / 1000.0:
                    # The previous sighting is too old to derive motion from; drop
                    # its velocity so skipped frames don't extrapolate this ball
                    # with an unrelated pitch's motion.
                    self.last_ball_vel = (0.0, 0.0)

            # Pitch state machine: start on a ball in the acquire area (or small
            # and high enough), count tracked frames, end once far below the zone.
            if self.pitch_active:
                self.track_frames += 1
            elif (
                self._in_acquire_window(cur_ball, full_w, full_h)
                or (by <= (sz_top + 0.62 * sz_h) and max(bw, bh) <= full_w * 0.018)
            ):
                self.pitch_active = True
                self.pitch_swinged = False
                self.track_frames = 1
            else:
                self.track_frames = 0
            self.pitch_last_seen_ns = ts_ns
            if self.pitch_active and by > (sz_bottom + 0.25 * sz_h):
                self.pitch_active = False
                self.pitch_swinged = False
                self.track_frames = 0

            ready, eta_local, pred_local = self._compute_swing_ready(
                self.last_ball,
                cur_ball,
                sz_left,
                sz_right,
                sz_top,
                sz_bottom,
                full_w,
            )
            if eta_local > 0:
                eta_ms = eta_local
            pred_good = max(pred_good, pred_local)
            aim_x, aim_y = self._aim_stick(bx - sz_cx, by - sz_cy)

            cooled = (ts_ns - self.last_swing_ts) > SWING_COOLDOWN_MS * 1_000_000
            if (
                ready
                and cooled
                and is_real_ball
                and (not self.pitch_swinged or (ts_ns - self.last_swing_ts) > 900 * 1_000_000)
                and self.track_frames >= MIN_TRACK_FRAMES_FOR_SWING
            ):
                self.last_swing_ts = ts_ns
                self.swing_hold_until = ts_ns + SWING_HOLD_MS * 1_000_000
                self.pitch_swinged = True
            self.prev_ball = self.last_ball
            self.last_ball = cur_ball
            eval_ball = self.last_ball
        elif self.last_ball is not None:
            if self.pitch_active and self.pitch_last_seen_ns > 0:
                if (ts_ns - self.pitch_last_seen_ns) > PITCH_END_TIMEOUT_MS * 1_000_000:
                    self.pitch_active = False
                    self.pitch_swinged = False
                    self.track_frames = 0
            bx, by, bw, bh, conf, last_ts = self.last_ball
            age_ns = ts_ns - last_ts
            if age_ns <= BALL_MEMORY_MS * 1_000_000:
                # Extrapolate the last sighting linearly with its velocity.
                age_s = age_ns / 1e9
                vx, vy = self.last_ball_vel
                bx = bx + vx * age_s
                by = by + vy * age_s
                aim_x, aim_y = self._aim_stick(bx - sz_cx, by - sz_cy)
                eval_ball = (bx, by, bw, bh, conf, ts_ns)

        if ts_ns < self.swing_hold_until:
            press_contact = 1

        if DEBUG_FORCE_AIM:
            aim_x = 25.0

        armed = 1
        in_flight = 1 if eval_ball is not None else 0
        debug_flags = (2 if is_real_ball else 0) | (4 if pred_good else 0) | (8 if self.pitch_active else 0)

        packet = self.gcvdata
        packet[0:4] = _fix32(aim_x)
        packet[4:8] = _fix32(aim_y)
        packet[8:10] = _int16(armed)
        packet[10:12] = _int16(in_flight)
        packet[12:14] = _int16(press_contact)
        packet[14:16] = _int16(press_power)
        packet[16:18] = _int16(eta_ms)
        packet[18:20] = _int16(debug_flags)

        self._draw_overlay(
            frame,
            (sz_left, sz_top, sz_right, sz_bottom),
            press_contact,
            ready,
            eval_ball,
            candidates,
            is_real_ball,
        )
        return frame, self.gcvdata