Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from typing import List, Optional, Tuple | |
| import torch | |
# Default directory holding the demo data: one subdirectory per task, each
# of which holds one subdirectory per scenario.
DEMO_ROOT: str = "data"
def list_demo_tasks(root: str = DEMO_ROOT) -> List[str]:
    """Return the sorted names of the task directories directly under *root*.

    Every immediate subdirectory of *root* is treated as one demo task;
    plain files are ignored.

    Args:
        root: Directory to scan. Defaults to ``DEMO_ROOT``.

    Returns:
        Sorted subdirectory names, or an empty list when *root* does not
        exist or is not a directory.
    """
    r = Path(root)
    # is_dir() rather than exists(): if *root* is a regular file,
    # iterdir() would raise NotADirectoryError.
    if not r.is_dir():
        return []
    return sorted(p.name for p in r.iterdir() if p.is_dir())
def list_demo_scenarios(task: str, root: str = DEMO_ROOT) -> List[str]:
    """Return the sorted names of the scenario directories of *task*.

    Every immediate subdirectory of ``root/task`` is treated as one scenario;
    plain files are ignored.

    Args:
        task: Task directory name under *root*.
        root: Demo data root. Defaults to ``DEMO_ROOT``.

    Returns:
        Sorted subdirectory names, or an empty list when ``root/task`` does
        not exist or is not a directory.
    """
    base = Path(root) / task
    # is_dir() rather than exists(): if root/task is a regular file,
    # iterdir() would raise NotADirectoryError.
    if not base.is_dir():
        return []
    return sorted(p.name for p in base.iterdir() if p.is_dir())
def _find_dataset_file(scenario_dir: Path) -> Optional[Path]:
    """Locate the dataset file inside *scenario_dir*.

    The well-known names ``train_data.pt``, ``data.pt`` and ``dataset.pt``
    are tried first, in that priority order. Otherwise the alphabetically
    first ``*.pt`` file is returned, and failing that the first ``*.p`` file.

    Args:
        scenario_dir: Directory to search (non-recursively).

    Returns:
        Path of the chosen file, or None when no candidate file exists.
    """
    preferred = ("train_data.pt", "data.pt", "dataset.pt")
    for name in preferred:
        cand = scenario_dir / name
        # is_file(): a *directory* that happens to be named "data.pt" is
        # not a dataset and must not be returned.
        if cand.is_file():
            return cand
    for pattern in ("*.pt", "*.p"):
        # glob() yields entries in arbitrary filesystem order; sort so the
        # fallback choice is deterministic across machines and runs.
        files = sorted(f for f in scenario_dir.glob(pattern) if f.is_file())
        if files:
            return files[0]
    return None
def list_demo_dataset_files(task: str, root: str = DEMO_ROOT) -> List[str]:
    """Return the dataset file path of each scenario of *task* that has one.

    Scenarios come from ``list_demo_scenarios`` (sorted by name), and each
    scenario's file is chosen by ``_find_dataset_file``. Scenarios without a
    recognizable dataset file are skipped.

    Args:
        task: Task directory name under *root*.
        root: Demo data root. Defaults to ``DEMO_ROOT``.

    Returns:
        Dataset file paths as strings, or an empty list when ``root/task``
        does not exist or is not a directory.
    """
    base = Path(root) / task
    # is_dir() rather than exists(): a regular file at root/task would
    # otherwise reach iterdir() during scenario listing and raise.
    if not base.is_dir():
        return []
    candidates = (
        _find_dataset_file(base / scen) for scen in list_demo_scenarios(task, root)
    )
    return [str(f) for f in candidates if f is not None]
def load_pt_dataset(path: str) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Load a ``torch.save`` file and return its channels and optional labels.

    Args:
        path: File to load; it is mapped onto the CPU. The payload is
            expected to be dict-like, with a required ``"channels"`` entry
            and an optional ``"labels"`` entry.

    Returns:
        ``(channels, labels)``, where ``labels`` is None when the payload has
        no ``"labels"`` entry.

    Raises:
        KeyError: If the payload has no ``"channels"`` entry. The message
            names the file and the keys that were found.
    """
    # SECURITY NOTE(review): torch.load unpickles the file. On torch versions
    # where ``weights_only`` defaults to False, this can execute arbitrary
    # code, so only load trusted files such as the bundled demo data.
    obj = torch.load(path, map_location="cpu")
    try:
        ch = obj["channels"]
    except KeyError:
        # A bare KeyError('channels') gives no hint which file was wrong or
        # what it contained; keep the exception type but add that context.
        found = sorted(map(str, obj.keys())) if hasattr(obj, "keys") else []
        raise KeyError(
            f"{path}: dataset has no 'channels' entry (found keys: {found})"
        ) from None
    y = obj.get("labels", None)
    return ch, y