File size: 1,422 Bytes
3c27f51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from pathlib import Path
from typing import List, Optional, Tuple
import torch

# Default root directory for demo data; used as the ``root`` default by the
# listing helpers below. They walk the layout
#   <root>/<task>/<scenario>/<dataset file>
# This is a relative path, so it is resolved against the current working directory.
DEMO_ROOT = "data"

def list_demo_tasks(root: str = DEMO_ROOT) -> List[str]:
    """Return the sorted names of the task directories directly under *root*.

    Only immediate sub-directories are reported; plain files are ignored.

    Args:
        root: Directory holding one sub-directory per task. A relative path
            is resolved against the current working directory.

    Returns:
        Task directory names in ascending lexicographic order, or ``[]`` when
        *root* is missing or is not a directory.
    """
    root_dir = Path(root)
    # Guard with is_dir() rather than exists(): iterdir() on a regular file
    # raises NotADirectoryError, whereas callers expect an empty listing.
    if not root_dir.is_dir():
        return []
    return sorted(p.name for p in root_dir.iterdir() if p.is_dir())

def list_demo_scenarios(task: str, root: str = DEMO_ROOT) -> List[str]:
    """Return the sorted names of the scenario directories of one task.

    Args:
        task: Task directory name under *root*.
        root: Directory holding one sub-directory per task.

    Returns:
        Names of the immediate sub-directories of ``<root>/<task>`` in
        ascending lexicographic order; ``[]`` when that path is missing or
        is not a directory.
    """
    task_dir = Path(root) / task
    # Guard with is_dir() rather than exists(): iterdir() on a regular file
    # raises NotADirectoryError.
    if not task_dir.is_dir():
        return []
    return sorted(p.name for p in task_dir.iterdir() if p.is_dir())

def _find_dataset_file(scenario_dir: Path) -> Optional[Path]:
    preferred = ["train_data.pt", "data.pt", "dataset.pt"]
    for name in preferred:
        cand = scenario_dir / name
        if cand.exists():
            return cand
    for ext in ("*.pt", "*.p"):
        files = list(scenario_dir.glob(ext))
        if files:
            return files[0]
    return None

def list_demo_dataset_files(task: str, root: str = DEMO_ROOT) -> List[str]:
    """Return one dataset file path (as ``str``) per scenario of *task*.

    Scenario names come from ``list_demo_scenarios(task, root)``. Each
    ``<root>/<task>/<scenario>`` directory is searched with
    ``_find_dataset_file``, and scenarios without a match are skipped.
    Returns ``[]`` when ``<root>/<task>`` does not exist.
    """
    task_dir = Path(root) / task
    if not task_dir.exists():
        return []
    candidates = (
        _find_dataset_file(task_dir / scenario_name)
        for scenario_name in list_demo_scenarios(task, root)
    )
    return [str(found) for found in candidates if found is not None]

def load_pt_dataset(path: str) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Load a saved dataset file and return its ``(channels, labels)`` pair.

    The object is deserialized with ``map_location="cpu"`` and then indexed
    like a mapping:
      * ``"channels"`` is read with ``[]``, so it is mandatory.
      * ``"labels"`` is read with ``.get``, so it falls back to ``None``.

    NOTE(review): presumably the file holds a dict saved via ``torch.save``;
    confirm against the producer.
    NOTE(review): ``torch.load`` unpickles its input. Only load trusted
    files; whether arbitrary objects are accepted depends on the installed
    torch version's ``weights_only`` default.
    """
    payload = torch.load(path, map_location="cpu")
    channels = payload["channels"]
    return channels, payload.get("labels", None)