Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| from glob import glob | |
| from PIL import Image | |
| def torch_transform(image): | |
| image = image / 255.0 | |
| image = np.transpose(image, (2, 0, 1)) | |
| return image | |
| def read_cv2_image(image_path): | |
| image = cv2.imread(image_path) | |
| image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |
| return image | |
| def read_mask(mask_path, shape): | |
| if not os.path.exists(mask_path): | |
| return np.ones(shape[1:]) > 0 | |
| mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) | |
| mask = mask > 0 | |
| return mask | |
| def tensorize(array, model_dtype, device): | |
| array = torch.from_numpy(array).to(device).to(model_dtype).unsqueeze(dim=0) | |
| return array | |
| def load_infer_data(config, device): | |
| image_dir = config['inference']['images'] | |
| mask_dir = config['inference']['masks'] | |
| image_paths = glob(os.path.join(image_dir, '*.png')) | |
| image_paths = sorted(image_paths) | |
| filenames = [os.path.basename(image_path)[:-4] for image_path in image_paths] | |
| cv2_images = [read_cv2_image(image_path) | |
| for image_path in image_paths] | |
| PIL_images = [Image.fromarray(cv2_image) for cv2_image in cv2_images] | |
| images = [torch_transform(cv2_image) for cv2_image in cv2_images] | |
| mask_paths = [image_path.replace(image_dir, mask_dir) | |
| for image_path in image_paths] | |
| masks = [read_mask(mask_path, images[i].shape) | |
| for (i, mask_path) in enumerate(mask_paths)] | |
| model_dtype = config['spherevit']['dtype'] | |
| images = [tensorize(image, model_dtype, device) for image in images] | |
| infer_data = { | |
| 'images': { | |
| 'PIL': PIL_images, | |
| 'cv2': cv2_images, | |
| 'torch': images | |
| }, | |
| 'masks': masks, | |
| 'filenames': filenames, | |
| 'size': len(images) | |
| } | |
| if config['env']['verbose']: | |
| s = 's' if len(images) > 1 else '' | |
| config['env']['logger'].info(f'Loaded {len(images)} image{s} in {model_dtype}') | |
| return infer_data | |