from pathlib import Path
from typing import List, Optional, Tuple

import torch

DEMO_ROOT = "data"


def list_demo_tasks(root: str = DEMO_ROOT) -> List[str]:
    """Return the sorted names of task directories under the demo root."""
    r = Path(root)
    if not r.exists():
        return []
    return sorted([p.name for p in r.iterdir() if p.is_dir()])


def list_demo_scenarios(task: str, root: str = DEMO_ROOT) -> List[str]:
    """Return the sorted names of scenario directories for a given task."""
    base = Path(root) / task
    if not base.exists():
        return []
    return sorted([p.name for p in base.iterdir() if p.is_dir()])


def _find_dataset_file(scenario_dir: Path) -> Optional[Path]:
    """Find a dataset file in a scenario directory, preferring well-known names."""
    preferred = ["train_data.pt", "data.pt", "dataset.pt"]
    for name in preferred:
        cand = scenario_dir / name
        if cand.exists():
            return cand
    # Fall back to the first file matching a known extension.
    for ext in ("*.pt", "*.p"):
        files = list(scenario_dir.glob(ext))
        if files:
            return files[0]
    return None


def list_demo_dataset_files(task: str, root: str = DEMO_ROOT) -> List[str]:
    """Return the dataset file paths found across all scenarios of a task."""
    out = []
    base = Path(root) / task
    if not base.exists():
        return out
    for scen in list_demo_scenarios(task, root):
        scen_dir = base / scen
        f = _find_dataset_file(scen_dir)
        if f is not None:
            out.append(str(f))
    return out


def load_pt_dataset(path: str) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Load a .pt dataset and return (channels, labels); labels may be None."""
    obj = torch.load(path, map_location="cpu")
    ch = obj["channels"]
    y = obj.get("labels", None)
    return ch, y
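

# Minimal usage sketch, assuming a data/<task>/<scenario>/*.pt layout where each
# file stores a dict with a "channels" tensor and an optional "labels" tensor,
# as the helpers above expect.
if __name__ == "__main__":
    for task in list_demo_tasks():
        for dataset_path in list_demo_dataset_files(task):
            channels, labels = load_pt_dataset(dataset_path)
            print(task, dataset_path, tuple(channels.shape),
                  None if labels is None else tuple(labels.shape))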