# Author: wi-lab
# Commit 3c27f51: "Create io_demo.py"
from pathlib import Path
from typing import List, Optional, Tuple
import torch
# Default root directory scanned by the demo helpers below. It is a relative
# path, so it resolves against the current working directory.
DEMO_ROOT: str = "data"
def list_demo_tasks(root: str = DEMO_ROOT) -> List[str]:
    """Return the sorted names of the sub-directories directly under *root*.

    Args:
        root: Directory to scan (anything ``pathlib.Path`` accepts).

    Returns:
        Alphabetically sorted sub-directory names. The list is empty when
        *root* does not exist or is not a directory.
    """
    root_dir = Path(root)
    # is_dir() rather than exists(): a regular file at *root* would pass an
    # exists() check and then make iterdir() raise NotADirectoryError.
    if not root_dir.is_dir():
        return []
    return sorted(p.name for p in root_dir.iterdir() if p.is_dir())
def list_demo_scenarios(task: str, root: str = DEMO_ROOT) -> List[str]:
    """Return the sorted names of the sub-directories under ``root/task``.

    Args:
        task: Name of the task directory inside *root*.
        root: Directory containing the task directories.

    Returns:
        Alphabetically sorted sub-directory names. The list is empty when
        ``root/task`` does not exist or is not a directory.
    """
    task_dir = Path(root) / task
    # is_dir() rather than exists(): a regular file named like the task would
    # pass an exists() check and then make iterdir() raise NotADirectoryError.
    if not task_dir.is_dir():
        return []
    return sorted(p.name for p in task_dir.iterdir() if p.is_dir())
def _find_dataset_file(scenario_dir: Path) -> Optional[Path]:
    """Pick the dataset file to load from *scenario_dir*.

    The search order is:

    1. The conventional names, in this priority: ``train_data.pt``,
       ``data.pt``, ``dataset.pt``.
    2. Otherwise, the first ``*.pt`` file in sorted order.
    3. Otherwise, the first ``*.p`` file in sorted order.

    Only regular files qualify; the search is not recursive.

    Args:
        scenario_dir: Directory to search.

    Returns:
        Path of the chosen file, or ``None`` if no candidate is found.
    """
    for name in ("train_data.pt", "data.pt", "dataset.pt"):
        cand = scenario_dir / name
        # is_file(): a sub-directory that happens to be called e.g. "data.pt"
        # is not a loadable dataset.
        if cand.is_file():
            return cand
    for pattern in ("*.pt", "*.p"):
        # glob() yields entries in filesystem order, which varies across
        # platforms and filesystems; sort so the fallback pick is deterministic.
        files = sorted(p for p in scenario_dir.glob(pattern) if p.is_file())
        if files:
            return files[0]
    return None
def list_demo_dataset_files(task: str, root: str = DEMO_ROOT) -> List[str]:
    """Return one dataset file path per scenario of *task*.

    For each scenario directory reported by ``list_demo_scenarios``, the file
    chosen by ``_find_dataset_file`` is collected. Scenarios without a
    candidate file are skipped.

    Args:
        task: Name of the task directory inside *root*.
        root: Directory containing the task directories.

    Returns:
        Paths (as strings) in scenario-name order. The list is empty when
        ``root/task`` does not exist or is not a directory.
    """
    base = Path(root) / task
    # is_dir() rather than exists(): a regular file at root/task would pass
    # an exists() check and then fail inside iterdir() with NotADirectoryError.
    if not base.is_dir():
        return []
    out: List[str] = []
    for scen in list_demo_scenarios(task, root):
        found = _find_dataset_file(base / scen)
        if found is not None:
            out.append(str(found))
    return out
def load_pt_dataset(path: str) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Load a saved dataset and return its ``(channels, labels)`` pair.

    The file is read with ``torch.load`` onto the CPU. The loaded object must
    be a mapping with a ``"channels"`` entry; ``"labels"`` is optional.

    Args:
        path: File to load (anything ``torch.load`` accepts as a path).

    Returns:
        ``(channels, labels)``, where *labels* is ``None`` when the file has
        no ``"labels"`` entry. The entries are returned as stored; their
        tensor type and shape are not checked here.

    Raises:
        TypeError: If the loaded object is not a mapping.
        KeyError: If the mapping has no ``"channels"`` entry.
    """
    from collections.abc import Mapping

    # NOTE(review): torch.load unpickles the file contents (fully, on torch
    # versions where weights_only defaults to False), so only load trusted files.
    obj = torch.load(path, map_location="cpu")
    if not isinstance(obj, Mapping):
        raise TypeError(
            f"{path!s}: expected a mapping with a 'channels' entry, "
            f"got {type(obj).__name__}"
        )
    if "channels" not in obj:
        raise KeyError(
            f"{path!s}: dataset has no 'channels' entry "
            f"(found keys: {sorted(map(str, obj))})"
        )
    return obj["channels"], obj.get("labels", None)