# DIPO/dataset/data_module.py
import os
import sys

# Make the repository root importable before pulling in local packages.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import json

import lightning.pytorch as pl
from torch.utils.data import DataLoader

import dataset
from dataset.mydataset import MyDataset

@dataset.register("dm_dipo")
class DIPODataModule(pl.LightningDataModule):
    """Lightning DataModule wrapping MyDataset for the DIPO splits."""

    def __init__(self, hparams):
        super().__init__()
        # Merge the external config into Lightning's hparams AttributeDict.
        self.hparams.update(hparams)

    def _prepare_split(self):
        # Load the train split and carve a validation subset out of it.
        with open(self.hparams.split_file, "r") as f:
            splits = json.load(f)
        train_ids = splits["train"]
        # Validation uses the train ids whose string does not contain "data".
        val_ids = [i for i in train_ids if "data" not in i]
        return train_ids, val_ids

    def _prepare_test_ids(self):
        test_which = self.hparams.get("test_which")
        if "acd" in test_which:
            # NOTE: hard-coded path to the ACD test split.
            with open("/home/users/ruiqi.wu/singapo/data/data_acd.json", "r") as f:
                file = json.load(f)
        elif "pm" in test_which:
            with open(self.hparams.split_file, "r") as f:
                file = json.load(f)
        else:
            raise NotImplementedError(
                f"Dataset {test_which} not implemented for DIPODataModule"
            )
        return file["test"]
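
    # Expected layout of the split JSON files, inferred from the reads above
    # (the id strings are illustrative placeholders, not actual dataset ids):
    #
    #   {"train": ["<model_id>", ...], "test": ["<model_id>", ...]}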

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            train_ids, val_ids = self._prepare_split()
            # NOTE: only the first 10 train ids and 50 val ids are used here.
            self.train_dataset = MyDataset(self.hparams, model_ids=train_ids[:10], mode="train")
            self.val_dataset = MyDataset(self.hparams, model_ids=val_ids[:50], mode="val")
        elif stage == "validate":
            val_ids = self._prepare_test_ids()
            self.val_dataset = MyDataset(self.hparams, model_ids=val_ids, mode="val")
        elif stage == "test":
            test_ids = self._prepare_test_ids()
            self.test_dataset = MyDataset(self.hparams, model_ids=test_ids, mode="test")
        else:
            raise NotImplementedError(f"Stage {stage} not implemented for DIPODataModule")

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            shuffle=True,
            # persistent_workers is only valid when num_workers > 0.
            persistent_workers=self.hparams.num_workers > 0,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=128,  # fixed validation batch size
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            shuffle=False,
            persistent_workers=self.hparams.num_workers > 0,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=1,  # evaluate one model at a time
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            shuffle=False,
            persistent_workers=self.hparams.num_workers > 0,
        )
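

# Minimal usage sketch (not part of the module API): assumes hparams carries
# at least `split_file`, `batch_size`, `num_workers`, and `test_which`; all
# values below are hypothetical placeholders.
if __name__ == "__main__":
    hparams = {
        "split_file": "data/splits.json",  # hypothetical split file
        "batch_size": 32,
        "num_workers": 4,
        "test_which": "pm",
    }
    dm = DIPODataModule(hparams)
    dm.setup(stage="fit")
    print(f"train batches: {len(dm.train_dataloader())}")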