import os
import random

import numpy as np
import torch
import torchvision.transforms as T
from torchvision.datasets.folder import default_loader
from torchvision.transforms.functional import InterpolationMode

from diffusion.data.builder import get_data_path, DATASETS
from diffusion.data.datasets.InternalData import InternalData, InternalDataSigma
from diffusion.data.datasets.utils import *
from diffusion.utils.logger import get_root_logger


def get_closest_ratio(height: float, width: float, ratios: dict):
    """Return the (height, width) bucket and ratio key whose value is closest to height / width."""
    aspect_ratio = height / width
    closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
    return ratios[closest_ratio], float(closest_ratio)
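
# Example (illustrative sketch): assuming an ASPECT_RATIO_512 bucket table is defined in
# diffusion.data.datasets.utils (as the 2048/2880 tables used below are), a 768x1024
# portrait image with aspect ratio 0.75 is mapped to the bucket whose key is closest to 0.75:
#
#     target_hw, ratio = get_closest_ratio(768, 1024, ASPECT_RATIO_512)
#     # target_hw is the bucketed [height, width] pair; ratio is float() of the closest key

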
class InternalDataMS(InternalData):
    """Aspect-ratio-bucketed (multi-scale) dataset that returns images or pre-computed VAE latents
    together with pre-computed T5 caption features."""

    def __init__(self,
                 root,
                 image_list_json='data_info.json',
                 transform=None,
                 resolution=256,
                 sample_subset=None,
                 load_vae_feat=False,
                 input_size=32,
                 patch_size=2,
                 mask_ratio=0.0,
                 mask_type='null',
                 load_mask_index=False,
                 real_prompt_ratio=1.0,
                 max_length=120,
                 config=None,
                 **kwargs):
        self.root = get_data_path(root)
        self.transform = transform
        self.load_vae_feat = load_vae_feat
        self.ori_imgs_nums = 0
        self.resolution = resolution
        self.N = int(resolution // (input_size // patch_size))
        self.mask_ratio = mask_ratio
        self.load_mask_index = load_mask_index
        self.mask_type = mask_type
        self.real_prompt_ratio = real_prompt_ratio
        self.max_lenth = max_length
        self.base_size = int(kwargs['aspect_ratio_type'].split('_')[-1])
        self.aspect_ratio = eval(kwargs.pop('aspect_ratio_type'))  # base aspect ratio
        self.meta_data_clean = []
        self.img_samples = []
        self.txt_feat_samples = []
        self.vae_feat_samples = []
        self.mask_index_samples = []
        self.ratio_index = {}
        self.ratio_nums = {}
        # self.weight_dtype = torch.float16 if self.real_prompt_ratio > 0 else torch.float32
        for k, v in self.aspect_ratio.items():
            self.ratio_index[float(k)] = []  # used for self.getitem
            self.ratio_nums[float(k)] = 0  # used for batch-sampler
        image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json]
        for json_file in image_list_json:
            meta_data = self.load_json(os.path.join(self.root, json_file))
            self.ori_imgs_nums += len(meta_data)
            meta_data_clean = [item for item in meta_data if item['ratio'] <= 4]
            self.meta_data_clean.extend(meta_data_clean)
            self.img_samples.extend([os.path.join(self.root.replace('InternData', 'InternImgs'), item['path']) for item in meta_data_clean])
            self.txt_feat_samples.extend([os.path.join(self.root, 'caption_features', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')) for item in meta_data_clean])
            self.vae_feat_samples.extend([os.path.join(self.root, f'img_vae_fatures_{resolution}_multiscale/ms', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')) for item in meta_data_clean])

        # Set loader and extensions
        if load_vae_feat:
            self.transform = None
            self.loader = self.vae_feat_loader
        else:
            self.loader = default_loader

        if sample_subset is not None:
            self.sample_subset(sample_subset)  # sample dataset for local debug
        # Scan (a third of) the dataset to collect per-ratio statistics for the batch sampler
        for i, info in enumerate(self.meta_data_clean[:len(self.meta_data_clean)//3]):
            ori_h, ori_w = info['height'], info['width']
            closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
            self.ratio_nums[closest_ratio] += 1
            if len(self.ratio_index[closest_ratio]) == 0:
                self.ratio_index[closest_ratio].append(i)
        # print(self.ratio_nums)

        logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
        logger.info(f"T5 max token length: {self.max_lenth}")

    def getdata(self, index):
        img_path = self.img_samples[index]
        npz_path = self.txt_feat_samples[index]
        npy_path = self.vae_feat_samples[index]
        ori_h, ori_w = self.meta_data_clean[index]['height'], self.meta_data_clean[index]['width']

        # Calculate the closest aspect ratio and resize & crop the image to [h, w]
        closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
        closest_size = list(map(lambda x: int(x), closest_size))
        self.closest_ratio = closest_ratio

        if self.load_vae_feat:
            try:
                img = self.loader(npy_path)
                if index not in self.ratio_index[closest_ratio]:
                    self.ratio_index[closest_ratio].append(index)
            except Exception:
                index = random.choice(self.ratio_index[closest_ratio])
                return self.getdata(index)
            h, w = (img.shape[1], img.shape[2])
            assert h, w == (ori_h // 8, ori_w // 8)  # NOTE: as written this only asserts that h is truthy; the comparison is treated as the message
        else:
            img = self.loader(img_path)
            h, w = (img.size[1], img.size[0])
            assert (h, w) == (ori_h, ori_w), f'image size {(h, w)} does not match metadata {(ori_h, ori_w)}'

        data_info = {'img_hw': torch.tensor([ori_h, ori_w], dtype=torch.float32)}
        data_info['aspect_ratio'] = closest_ratio
        data_info["mask_type"] = self.mask_type

        txt_info = np.load(npz_path)
        txt_fea = torch.from_numpy(txt_info['caption_feature'])
        attention_mask = torch.ones(1, 1, txt_fea.shape[1])
        if 'attention_mask' in txt_info.keys():
            attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]

        if not self.load_vae_feat:
            if closest_size[0] / ori_h > closest_size[1] / ori_w:
                resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h)
            else:
                resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1]
            self.transform = T.Compose([
                T.Lambda(lambda img: img.convert('RGB')),
                T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC),  # Image.BICUBIC
                T.CenterCrop(closest_size),
                T.ToTensor(),
                T.Normalize([.5], [.5]),
            ])

        if self.transform:
            img = self.transform(img)

        return img, txt_fea, attention_mask, data_info
    def __getitem__(self, idx):
        for _ in range(20):
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = random.choice(self.ratio_index[self.closest_ratio])
        raise RuntimeError('Too many bad data samples.')


class InternalDataMSSigma(InternalDataSigma):
    """Aspect-ratio-bucketed dataset that returns images or pre-computed SDXL-VAE latents together
    with either raw prompts or pre-computed T5 caption features, optionally mixing real user
    prompts with ShareGPT4V captions."""

    def __init__(self,
                 root,
                 image_list_json='data_info.json',
                 transform=None,
                 resolution=256,
                 sample_subset=None,
                 load_vae_feat=False,
                 load_t5_feat=False,
                 input_size=32,
                 patch_size=2,
                 mask_ratio=0.0,
                 mask_type='null',
                 load_mask_index=False,
                 real_prompt_ratio=1.0,
                 max_length=300,
                 config=None,
                 **kwargs):
        self.root = get_data_path(root)
        self.transform = transform
        self.load_vae_feat = load_vae_feat
        self.load_t5_feat = load_t5_feat
        self.ori_imgs_nums = 0
        self.resolution = resolution
        self.N = int(resolution // (input_size // patch_size))
        self.mask_ratio = mask_ratio
        self.load_mask_index = load_mask_index
        self.mask_type = mask_type
        self.real_prompt_ratio = real_prompt_ratio
        self.max_lenth = max_length
        self.base_size = int(kwargs['aspect_ratio_type'].split('_')[-1])
        self.aspect_ratio = eval(kwargs.pop('aspect_ratio_type'))  # base aspect ratio
        self.meta_data_clean = []
        self.img_samples = []
        self.txt_samples = []
        self.sharegpt4v_txt_samples = []
        self.txt_feat_samples = []
        self.vae_feat_samples = []
        self.mask_index_samples = []
        self.ratio_index = {}
        self.ratio_nums = {}
        self.gpt4v_txt_feat_samples = []
        self.weight_dtype = torch.float16 if self.real_prompt_ratio > 0 else torch.float32
        self.interpolate_model = InterpolationMode.BICUBIC
        if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]:
            self.interpolate_model = InterpolationMode.LANCZOS
        suffix = ''
        for k, v in self.aspect_ratio.items():
            self.ratio_index[float(k)] = []  # used for self.getitem
            self.ratio_nums[float(k)] = 0  # used for batch-sampler

        logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
        logger.info(f"T5 max token length: {self.max_lenth}")
        logger.info(f"ratio of real user prompt: {self.real_prompt_ratio}")

        image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json]
        for json_file in image_list_json:
            meta_data = self.load_json(os.path.join(self.root, json_file))
            logger.info(f"{json_file} data volume: {len(meta_data)}")
            self.ori_imgs_nums += len(meta_data)
            meta_data_clean = [item for item in meta_data if item['ratio'] <= 4.5]
            self.meta_data_clean.extend(meta_data_clean)
            self.img_samples.extend([
                os.path.join(self.root.replace('InternData' + suffix, 'InternImgs'), item['path']) for item in meta_data_clean
            ])
            self.txt_samples.extend([item['prompt'] for item in meta_data_clean])
            self.sharegpt4v_txt_samples.extend([item['sharegpt4v'] if 'sharegpt4v' in item else '' for item in meta_data_clean])
            self.txt_feat_samples.extend([
                os.path.join(
                    self.root,
                    'caption_features_new',
                    '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')
                ) for item in meta_data_clean
            ])
            self.gpt4v_txt_feat_samples.extend([
                os.path.join(
                    self.root,
                    'sharegpt4v_caption_features_new',
                    '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')
                ) for item in meta_data_clean
            ])
            self.vae_feat_samples.extend([
                os.path.join(
                    self.root + suffix,
                    f'img_sdxl_vae_features_{resolution}resolution_ms_new',
                    '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')
                ) for item in meta_data_clean
            ])
        if self.real_prompt_ratio < 1:
            assert len(self.sharegpt4v_txt_samples[0]) != 0

        # Set loader and extensions
        if load_vae_feat:
            self.transform = None
            self.loader = self.vae_feat_loader
        else:
            self.loader = default_loader

        if sample_subset is not None:
            self.sample_subset(sample_subset)  # sample dataset for local debug

        # Scan (a third of) the dataset to collect per-ratio statistics for the batch sampler
        for i, info in enumerate(self.meta_data_clean[:len(self.meta_data_clean)//3]):
            ori_h, ori_w = info['height'], info['width']
            closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
            self.ratio_nums[closest_ratio] += 1
            if len(self.ratio_index[closest_ratio]) == 0:
                self.ratio_index[closest_ratio].append(i)

    def getdata(self, index):
        img_path = self.img_samples[index]
        real_prompt = random.random() < self.real_prompt_ratio
        npz_path = self.txt_feat_samples[index] if real_prompt else self.gpt4v_txt_feat_samples[index]
        txt = self.txt_samples[index] if real_prompt else self.sharegpt4v_txt_samples[index]
        npy_path = self.vae_feat_samples[index]
        data_info = {}

        ori_h, ori_w = self.meta_data_clean[index]['height'], self.meta_data_clean[index]['width']

        # Calculate the closest aspect ratio and resize & crop the image to [h, w]
        closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
        closest_size = list(map(lambda x: int(x), closest_size))
        self.closest_ratio = closest_ratio

        if self.load_vae_feat:
            img = self.loader(npy_path)
            if index not in self.ratio_index[closest_ratio]:
                self.ratio_index[closest_ratio].append(index)
            h, w = (img.shape[1], img.shape[2])
            assert h, w == (ori_h // 8, ori_w // 8)  # NOTE: as written this only asserts that h is truthy; the comparison is treated as the message
        else:
            img = self.loader(img_path)
            h, w = (img.size[1], img.size[0])
            assert (h, w) == (ori_h, ori_w), f'image size {(h, w)} does not match metadata {(ori_h, ori_w)}'

        data_info['img_hw'] = torch.tensor([ori_h, ori_w], dtype=torch.float32)
        data_info['aspect_ratio'] = closest_ratio
        data_info["mask_type"] = self.mask_type

        attention_mask = torch.ones(1, 1, self.max_lenth)
        if self.load_t5_feat:
            txt_info = np.load(npz_path)
            txt_fea = torch.from_numpy(txt_info['caption_feature'])
            if 'attention_mask' in txt_info.keys():
                attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
            if txt_fea.shape[1] != self.max_lenth:
                # pad the caption feature to max length by repeating its last token,
                # and pad the attention mask with zeros accordingly
                txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth - txt_fea.shape[1], 1)], dim=1).to(self.weight_dtype)
                attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth - attention_mask.shape[-1])], dim=-1)
        else:
            txt_fea = txt

        if not self.load_vae_feat:
            if closest_size[0] / ori_h > closest_size[1] / ori_w:
                resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h)
            else:
                resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1]
            self.transform = T.Compose([
                T.Lambda(lambda img: img.convert('RGB')),
                T.Resize(resize_size, interpolation=self.interpolate_model),  # Image.BICUBIC
                T.CenterCrop(closest_size),
                T.ToTensor(),
                T.Normalize([.5], [.5]),
            ])

        if self.transform:
            img = self.transform(img)

        return img, txt_fea, attention_mask.to(torch.int16), data_info

    def __getitem__(self, idx):
        for _ in range(20):
            try:
                data = self.getdata(idx)
                return data
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = random.choice(self.ratio_index[self.closest_ratio])
        raise RuntimeError('Too many bad data samples.')
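

# A minimal usage sketch (illustrative only): the root path, JSON name, and aspect-ratio
# table name below are placeholder assumptions, and the call only works with an on-disk
# layout matching the paths built in __init__ (an InternData/ metadata folder next to InternImgs/).
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = InternalDataMSSigma(
        root='data/InternData',                # placeholder; resolved through get_data_path
        image_list_json='data_info.json',
        resolution=512,
        load_vae_feat=False,                   # load raw images and transform them on the fly
        load_t5_feat=False,                    # return the raw prompt string instead of T5 features
        aspect_ratio_type='ASPECT_RATIO_512',  # assumes this table exists in diffusion.data.datasets.utils
    )
    # In training, samples are typically grouped by data_info['aspect_ratio'] (see the ratio_nums
    # bookkeeping above) so each batch shares one bucket size; batch_size=1 sidesteps that here.
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
    img, txt, attn_mask, data_info = next(iter(loader))
    print(img.shape, data_info['aspect_ratio'])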