File size: 2,852 Bytes
c28dddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import json
import dataset
import lightning.pytorch as pl
from torch.utils.data import DataLoader
from dataset.mydataset import MyDataset

@dataset.register("dm_dipo")
class DIPODataModule(pl.LightningDataModule):
    """Lightning data module registered under the name ``"dm_dipo"``.

    Builds ``MyDataset`` instances for the fit / validate / test stages from
    model-id lists stored in JSON split files and wraps them in
    ``DataLoader`` objects.

    Hyper-parameters read from ``self.hparams``:
        split_file: path to a JSON file holding ``"train"`` and ``"test"``
            id lists.
        test_which: selector for the evaluation split; must contain
            ``"acd"`` or ``"pm"`` (only consulted for validate/test).
        acd_split_file: optional path to the ACD split JSON; defaults to
            ``_DEFAULT_ACD_SPLIT_FILE`` when absent.
        batch_size: batch size of the training loader.
        num_workers: ``num_workers`` passed to every ``DataLoader``.
    """

    # Fallback location of the ACD split file, used when the hparams do not
    # provide ``acd_split_file``. Kept for backward compatibility.
    _DEFAULT_ACD_SPLIT_FILE = "/home/users/ruiqi.wu/singapo/data/data_acd.json"

    def __init__(self, hparams):
        """Store ``hparams`` by merging them into ``self.hparams``."""
        super().__init__()
        self.hparams.update(hparams)

    @staticmethod
    def _load_json(path):
        """Return the parsed content of the JSON file at ``path``."""
        with open(path, "r") as f:
            return json.load(f)

    def _prepare_split(self):
        """Return ``(train_ids, val_ids)`` read from ``hparams.split_file``.

        ``val_ids`` is the subset of the ``"train"`` ids that do not contain
        the substring ``"data"``.
        NOTE(review): validation ids are therefore also training ids —
        confirm this overlap is intended.
        """
        train_ids = self._load_json(self.hparams.split_file)["train"]
        val_ids = [model_id for model_id in train_ids if "data" not in model_id]
        return train_ids, val_ids

    def _prepare_test_ids(self):
        """Return the ``"test"`` id list selected by ``hparams.test_which``.

        Raises:
            NotImplementedError: if ``test_which`` is missing or contains
                neither ``"acd"`` nor ``"pm"``.
        """
        test_which = self.hparams.get("test_which")
        # A missing/None selector must reach the NotImplementedError below
        # instead of failing with a TypeError on the ``in`` check.
        selector = test_which or ""
        if "acd" in selector:
            path = self.hparams.get("acd_split_file", self._DEFAULT_ACD_SPLIT_FILE)
        elif "pm" in selector:
            path = self.hparams.split_file
        else:
            raise NotImplementedError(
                f"Dataset {test_which} not implemented for {type(self).__name__}"
            )
        return self._load_json(path)["test"]

    def setup(self, stage=None):
        """Create the dataset(s) needed for ``stage``.

        ``"fit"`` (or ``None``) builds the train and val datasets from the
        training split; ``"validate"`` builds the val dataset from the test
        ids; ``"test"`` builds the test dataset from the test ids.

        Raises:
            NotImplementedError: for any other ``stage`` value.
        """
        if stage == "fit" or stage is None:
            train_ids, val_ids = self._prepare_split()
            # NOTE(review): only the first 10 train ids and first 50 val ids
            # are used — looks like a debugging subset; confirm before a
            # full training run.
            self.train_dataset = MyDataset(self.hparams, model_ids=train_ids[:10], mode="train")
            self.val_dataset = MyDataset(self.hparams, model_ids=val_ids[:50], mode="val")
        elif stage == "validate":
            # Standalone validation runs on the test ids, not the train split.
            val_ids = self._prepare_test_ids()
            self.val_dataset = MyDataset(self.hparams, model_ids=val_ids, mode="val")
        elif stage == "test":
            test_ids = self._prepare_test_ids()
            self.test_dataset = MyDataset(self.hparams, model_ids=test_ids, mode="test")
        else:
            raise NotImplementedError(
                f"Stage {stage} not implemented for {type(self).__name__}"
            )

    def _make_loader(self, data, batch_size, shuffle):
        """Build a ``DataLoader`` over ``data`` with the shared settings."""
        num_workers = self.hparams.num_workers
        return DataLoader(
            data,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            shuffle=shuffle,
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers == 0, so only enable it when workers exist.
            persistent_workers=num_workers > 0,
        )

    def train_dataloader(self):
        """Return the shuffled training loader using ``hparams.batch_size``."""
        return self._make_loader(
            self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True
        )

    def val_dataloader(self):
        """Return the unshuffled validation loader (fixed batch size 128)."""
        return self._make_loader(self.val_dataset, batch_size=128, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled test loader (one sample per batch)."""
        return self._make_loader(self.test_dataset, batch_size=1, shuffle=False)