"""Discover and load demo datasets laid out as ``<root>/<task>/<scenario>/<file>``."""

from pathlib import Path
from typing import List, Optional, Tuple

import torch

DEMO_ROOT = "data"

# File names tried first, in priority order, when looking for a scenario's dataset.
_PREFERRED_DATASET_NAMES = ("train_data.pt", "data.pt", "dataset.pt")
# Glob patterns tried (in order) when none of the preferred names is present.
_FALLBACK_DATASET_PATTERNS = ("*.pt", "*.p")


def _subdirectory_names(directory: Path) -> List[str]:
    """Return the sorted names of the immediate sub-directories of *directory*.

    Returns an empty list when *directory* is missing or is not a directory,
    so a stray regular file at that path does not raise ``NotADirectoryError``.
    """
    if not directory.is_dir():
        return []
    return sorted(entry.name for entry in directory.iterdir() if entry.is_dir())


def list_demo_tasks(root: str = DEMO_ROOT) -> List[str]:
    """Return the sorted task names (sub-directories) found directly under *root*.

    Args:
        root: Directory that contains one sub-directory per task.

    Returns:
        Sorted sub-directory names, or ``[]`` if *root* is not a directory.
    """
    return _subdirectory_names(Path(root))


def list_demo_scenarios(task: str, root: str = DEMO_ROOT) -> List[str]:
    """Return the sorted scenario names (sub-directories) of ``<root>/<task>``.

    Args:
        task: Name of the task directory under *root*.
        root: Directory that contains the task directories.

    Returns:
        Sorted sub-directory names, or ``[]`` if ``<root>/<task>`` is not a
        directory.
    """
    return _subdirectory_names(Path(root) / task)


def _find_dataset_file(scenario_dir: Path) -> Optional[Path]:
    """Pick the dataset file for a scenario directory.

    The preferred file names are checked first, in priority order. If none is
    present, the alphabetically first regular file matching ``*.pt`` (then
    ``*.p``) is used.

    Args:
        scenario_dir: Directory to search (non-recursively).

    Returns:
        Path of the selected file, or ``None`` when nothing matches.
    """
    for name in _PREFERRED_DATASET_NAMES:
        candidate = scenario_dir / name
        # is_file() rather than exists(): a directory with a matching name is
        # not a loadable dataset.
        if candidate.is_file():
            return candidate

    for pattern in _FALLBACK_DATASET_PATTERNS:
        # glob() yields entries in arbitrary filesystem order; sort so the
        # same file is selected on every run and every platform.
        matches = sorted(p for p in scenario_dir.glob(pattern) if p.is_file())
        if matches:
            return matches[0]
    return None


def list_demo_dataset_files(task: str, root: str = DEMO_ROOT) -> List[str]:
    """Return one dataset file path per scenario of *task*.

    Scenarios are visited in sorted order; scenarios without a dataset file
    are skipped.

    Args:
        task: Name of the task directory under *root*.
        root: Directory that contains the task directories.

    Returns:
        Dataset file paths as strings; ``[]`` if the task directory is absent.
    """
    base = Path(root) / task
    found = (
        _find_dataset_file(base / scenario)
        for scenario in list_demo_scenarios(task, root)
    )
    return [str(path) for path in found if path is not None]


def load_pt_dataset(path: str) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Load a saved dataset and return its ``"channels"`` and ``"labels"`` entries.

    The file is expected to contain a mapping with a ``"channels"`` key and an
    optional ``"labels"`` key (presumably tensors, per the return annotation).

    Args:
        path: Path of a file previously written with ``torch.save``.

    Returns:
        ``(channels, labels)`` where ``labels`` is ``None`` if the key is
        missing.

    Raises:
        KeyError: If the loaded mapping has no ``"channels"`` entry.
    """
    # SECURITY NOTE: torch.load deserializes with pickle; only call this on
    # files from a trusted source.
    payload = torch.load(path, map_location="cpu")
    channels = payload["channels"]
    labels = payload.get("labels")
    return channels, labels