import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))

import json

import dataset
import lightning.pytorch as pl
from torch.utils.data import DataLoader

from dataset.mydataset import MyDataset


@dataset.register("dm_dipo")
class DIPODataModule(pl.LightningDataModule):
    """LightningDataModule that builds train/val/test splits from JSON id lists."""

    def __init__(self, hparams):
        super().__init__()
        self.hparams.update(hparams)

    def _prepare_split(self):
        with open(self.hparams.split_file, "r") as f:
            splits = json.load(f)
        train_ids = splits["train"]
        # Validation ids are the training ids whose name does not contain "data".
        val_ids = [i for i in train_ids if "data" not in i]
        return train_ids, val_ids

    def _prepare_test_ids(self):
        test_which = self.hparams.get("test_which", "")
        if "acd" in test_which:
            # Hard-coded path to the ACD test split.
            with open("/home/users/ruiqi.wu/singapo/data/data_acd.json", "r") as f:
                file = json.load(f)
        elif "pm" in test_which:
            with open(self.hparams.split_file, "r") as f:
                file = json.load(f)
        else:
            raise NotImplementedError(
                f"Dataset {test_which} not implemented for DIPODataModule"
            )
        return file["test"]

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            train_ids, val_ids = self._prepare_split()
            # Note: only the first 10 train ids and first 50 val ids are used.
            self.train_dataset = MyDataset(self.hparams, model_ids=train_ids[:10], mode="train")
            self.val_dataset = MyDataset(self.hparams, model_ids=val_ids[:50], mode="val")
        elif stage == "validate":
            val_ids = self._prepare_test_ids()
            self.val_dataset = MyDataset(self.hparams, model_ids=val_ids, mode="val")
        elif stage == "test":
            test_ids = self._prepare_test_ids()
            self.test_dataset = MyDataset(self.hparams, model_ids=test_ids, mode="test")
        else:
            raise NotImplementedError(f"Stage {stage} not implemented for DIPODataModule")

    def train_dataloader(self):
        # persistent_workers=True requires num_workers > 0.
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            shuffle=True,
            persistent_workers=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=128,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            shuffle=False,
            persistent_workers=True,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=1,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            shuffle=False,
            persistent_workers=True,
        )
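

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The hparams keys mirror the ones
# this module reads ("split_file", "test_which", "batch_size", "num_workers");
# the split-file path and the concrete values below are assumptions for
# demonstration, not values taken from the repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    hparams = {
        "split_file": "data/splits.json",  # assumed path to a JSON with "train"/"test" id lists
        "test_which": "pm",                # selects which test split to load ("acd" or "pm")
        "batch_size": 4,
        "num_workers": 2,                  # must be > 0 because persistent_workers=True
    }
    dm = DIPODataModule(hparams)
    dm.setup(stage="fit")
    loader = dm.train_dataloader()
    print(f"train batches: {len(loader)}")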