Zion Boggan
repos/Pitch Tracker CV/tools/yolo_split_dataset.py
zionboggan.com ↗
94 lines · python
History for this file →
1
"""Split labeled YOLO data into train/val/test folders."""
2
from __future__ import annotations
3
 
4
import argparse
5
import random
6
import shutil
7
from pathlib import Path
8
 
9
from rich.console import Console
10
 
11
# Project root: two directory levels above this script (the script lives in tools/).
ROOT = Path(__file__).resolve().parent.parent
# Dataset root; the unsplit source data lives under images/all and labels/all.
DATA = ROOT.joinpath("datasets", "mlb26_ball_yolo")
IMG_ALL = DATA.joinpath("images", "all")
LBL_ALL = DATA.joinpath("labels", "all")
# Shared rich console for coloured status output.
console = Console()
16
 
17
def reset_split_dirs() -> None:
    """Recreate empty train/val/test folders under both images/ and labels/.

    Any existing split folder is deleted together with its contents, so every
    run starts from a clean slate. The images/all and labels/all sources are
    never touched.
    """
    targets = [
        DATA / kind / split
        for kind in ("images", "labels")
        for split in ("train", "val", "test")
    ]
    for target in targets:
        if target.exists():
            shutil.rmtree(target)
        target.mkdir(parents=True, exist_ok=True)
24
 
25
def _label_has_boxes(label: Path) -> bool:
    """Return True if the YOLO label file contains any non-whitespace content.

    A label file that exists but holds only whitespace (for example a single
    newline) lists no objects, so it must count as a negative image even though
    its size on disk is non-zero.
    """
    return bool(label.read_text(encoding="utf-8").strip())


def _split_pairs(
    pairs: list[tuple[Path, Path]], val: float, test: float
) -> dict[str, list[tuple[Path, Path]]]:
    """Partition *pairs* into test/val/train slices by fraction.

    The first ``round(n * test)`` items go to test, the next ``round(n * val)``
    to val, and the rest to train. No shuffling happens here; pass an already
    shuffled list to get a random split.
    """
    n = len(pairs)
    n_test = round(n * test)
    n_val = round(n * val)
    return {
        "test": pairs[:n_test],
        "val": pairs[n_test:n_test + n_val],
        "train": pairs[n_test + n_val:],
    }


def main() -> int:
    """Split images/all + labels/all into train/val/test folders under DATA.

    Only images with a matching ``<stem>.txt`` label are used. Labels with
    content are positives (ball present); blank labels are negatives, which
    are optionally capped at ``--max-neg-ratio`` per positive. Positives and
    negatives are split independently so every split gets roughly the same
    ball/empty mix.

    Returns:
        0 on success, 2 if the source folders are missing, 3 if no positive
        image/label pairs exist. Invalid CLI values exit through argparse
        (status 2).
    """
    ap = argparse.ArgumentParser(description="Split YOLO ball dataset.")
    ap.add_argument("--val", type=float, default=0.18, help="Validation fraction.")
    ap.add_argument("--test", type=float, default=0.05, help="Test fraction.")
    ap.add_argument("--seed", type=int, default=26)
    ap.add_argument(
        "--max-neg-ratio",
        type=float,
        default=4.0,
        help="Maximum empty/no-ball images per positive image. Use 0 to keep all negatives.",
    )
    args = ap.parse_args()

    # Out-of-range fractions would leave train empty or make the slices overlap
    # nonsensically. The chained comparisons also reject NaN.
    if not (0.0 <= args.val < 1.0 and 0.0 <= args.test < 1.0):
        ap.error("--val and --test must each be in [0, 1).")
    if not args.val + args.test < 1.0:
        ap.error("--val + --test must be less than 1 so the train split is non-empty.")
    if not args.max_neg_ratio >= 0.0:
        ap.error("--max-neg-ratio must be >= 0.")

    if not IMG_ALL.exists() or not LBL_ALL.exists():
        console.print(f"[red]Expected images at {IMG_ALL} and labels at {LBL_ALL}.[/red]")
        return 2

    image_suffixes = {".jpg", ".jpeg", ".png"}
    # Sorting makes the input order independent of filesystem iteration order,
    # so a given --seed always yields the same split.
    images = sorted(p for p in IMG_ALL.iterdir() if p.suffix.lower() in image_suffixes)
    positives: list[tuple[Path, Path]] = []
    negatives: list[tuple[Path, Path]] = []
    unlabeled = 0
    for img in images:
        label = LBL_ALL / f"{img.stem}.txt"
        if not label.exists():
            unlabeled += 1
            continue
        if _label_has_boxes(label):
            positives.append((img, label))
        else:
            negatives.append((img, label))

    if unlabeled:
        console.print(f"[yellow]Skipped {unlabeled} image(s) with no label file.[/yellow]")

    if not positives:
        console.print("[red]No labeled image/label pairs found.[/red]")
        return 3

    rng = random.Random(args.seed)
    rng.shuffle(positives)
    rng.shuffle(negatives)
    # A ratio of 0 means "keep every negative"; otherwise cap at ratio * #positives.
    if args.max_neg_ratio > 0:
        max_negs = round(len(positives) * args.max_neg_ratio)
        negatives = negatives[:max_negs]

    pos_buckets = _split_pairs(positives, args.val, args.test)
    neg_buckets = _split_pairs(negatives, args.val, args.test)

    reset_split_dirs()
    for split in ("test", "val", "train"):
        pairs = pos_buckets[split] + neg_buckets[split]
        for img, label in pairs:
            shutil.copy2(img, DATA / "images" / split / img.name)
            shutil.copy2(label, DATA / "labels" / split / label.name)
        # Counts come straight from the buckets; no need to re-stat every label.
        pos_count = len(pos_buckets[split])
        neg_count = len(neg_buckets[split])
        console.print(f"[green]{split}: {len(pairs)} ({pos_count} ball, {neg_count} empty)[/green]")

    console.print(f"[bold green]Dataset ready at {DATA}[/bold green]")
    return 0
92
 
93
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)