| 1 | """Split labeled YOLO data into train/val/test folders.""" |
| 2 | from __future__ import annotations |
| 3 | |
| 4 | import argparse |
| 5 | import random |
| 6 | import shutil |
| 7 | from pathlib import Path |
| 8 | |
| 9 | from rich.console import Console |
| 10 | |
# Project root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Dataset root; every input and output folder of the split lives below it.
DATA = ROOT.joinpath("datasets", "mlb26_ball_yolo")
# Unsplit source pool: all images and their same-stem label files.
IMG_ALL = DATA.joinpath("images", "all")
LBL_ALL = DATA.joinpath("labels", "all")
# Shared rich console; the script prints its colored status and error messages through it.
console = Console()
| 16 | |
def reset_split_dirs() -> None:
    """Recreate empty train/val/test folders under both images/ and labels/.

    A split folder that already exists is removed together with its contents
    before being created again, so files from an earlier run never linger.
    """
    targets = [
        DATA / kind / split
        for kind in ("images", "labels")
        for split in ("train", "val", "test")
    ]
    for target in targets:
        if target.exists():
            shutil.rmtree(target)
        target.mkdir(parents=True, exist_ok=True)
| 24 | |
| 25 | def main() -> int: |
| 26 | ap = argparse.ArgumentParser(description="Split YOLO ball dataset.") |
| 27 | ap.add_argument("--val", type=float, default=0.18, help="Validation fraction.") |
| 28 | ap.add_argument("--test", type=float, default=0.05, help="Test fraction.") |
| 29 | ap.add_argument("--seed", type=int, default=26) |
| 30 | ap.add_argument( |
| 31 | "--max-neg-ratio", |
| 32 | type=float, |
| 33 | default=4.0, |
| 34 | help="Maximum empty/no-ball images per positive image. Use 0 to keep all negatives.", |
| 35 | ) |
| 36 | args = ap.parse_args() |
| 37 | |
| 38 | if not IMG_ALL.exists() or not LBL_ALL.exists(): |
| 39 | console.print(f"[red]Expected images at {IMG_ALL} and labels at {LBL_ALL}.[/red]") |
| 40 | return 2 |
| 41 | |
| 42 | images = sorted([p for p in IMG_ALL.iterdir() if p.suffix.lower() in {".jpg", ".jpeg", ".png"}]) |
| 43 | positives = [] |
| 44 | negatives = [] |
| 45 | for img in images: |
| 46 | label = LBL_ALL / f"{img.stem}.txt" |
| 47 | if label.exists(): |
| 48 | if label.stat().st_size > 0: |
| 49 | positives.append((img, label)) |
| 50 | else: |
| 51 | negatives.append((img, label)) |
| 52 | |
| 53 | if not positives: |
| 54 | console.print("[red]No labeled image/label pairs found.[/red]") |
| 55 | return 3 |
| 56 | |
| 57 | rng = random.Random(args.seed) |
| 58 | rng.shuffle(positives) |
| 59 | rng.shuffle(negatives) |
| 60 | if args.max_neg_ratio and negatives: |
| 61 | max_negs = int(round(len(positives) * args.max_neg_ratio)) |
| 62 | negatives = negatives[:max_negs] |
| 63 | |
| 64 | def split_pairs(pairs): |
| 65 | n = len(pairs) |
| 66 | n_test = int(round(n * args.test)) |
| 67 | n_val = int(round(n * args.val)) |
| 68 | return { |
| 69 | "test": pairs[:n_test], |
| 70 | "val": pairs[n_test:n_test + n_val], |
| 71 | "train": pairs[n_test + n_val:], |
| 72 | } |
| 73 | |
| 74 | pos_buckets = split_pairs(positives) |
| 75 | neg_buckets = split_pairs(negatives) |
| 76 | buckets = { |
| 77 | split: pos_buckets[split] + neg_buckets[split] |
| 78 | for split in ("test", "val", "train") |
| 79 | } |
| 80 | |
| 81 | reset_split_dirs() |
| 82 | for split, split_pairs in buckets.items(): |
| 83 | for img, label in split_pairs: |
| 84 | shutil.copy2(img, DATA / "images" / split / img.name) |
| 85 | shutil.copy2(label, DATA / "labels" / split / label.name) |
| 86 | pos_count = sum(1 for _, label in split_pairs if label.stat().st_size > 0) |
| 87 | neg_count = len(split_pairs) - pos_count |
| 88 | console.print(f"[green]{split}: {len(split_pairs)} ({pos_count} ball, {neg_count} empty)[/green]") |
| 89 | |
| 90 | console.print(f"[bold green]Dataset ready at {DATA}[/bold green]") |
| 91 | return 0 |
| 92 | |
if __name__ == "__main__":
    # Hand main()'s integer status to the interpreter as the process exit code.
    exit_code = main()
    raise SystemExit(exit_code)