# Inference data-loading utilities.
import os
import cv2
import torch
import numpy as np
from glob import glob
from PIL import Image
def torch_transform(image):
    """Scale an HWC image from [0, 255] to [0, 1] and reorder axes to CHW."""
    scaled = image / 255.0
    return np.transpose(scaled, (2, 0, 1))
def read_cv2_image(image_path):
    """Read an image file from disk and return it as an RGB numpy array.

    Args:
        image_path: Path to an image file readable by OpenCV.

    Returns:
        (H, W, 3) uint8 numpy array in RGB channel order.

    Raises:
        FileNotFoundError: if the file is missing or cannot be decoded.
            cv2.imread signals failure by returning None instead of raising,
            which would otherwise surface as a cryptic cvtColor error.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f'Could not read image: {image_path}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image
def read_mask(mask_path, shape):
    """Load a boolean mask, defaulting to all-True when no mask file exists.

    Args:
        mask_path: Path to a grayscale mask image (may not exist).
        shape: (C, H, W) shape of the corresponding transformed image;
            only shape[1:] (H, W) is used for the default mask.

    Returns:
        Boolean (H, W) numpy array; True marks pixels with nonzero mask
        values (or every pixel when no mask file is present).

    Raises:
        ValueError: if the file exists but cv2 cannot decode it.
    """
    if not os.path.exists(mask_path):
        # No mask provided for this image: treat every pixel as valid.
        return np.ones(shape[1:]) > 0
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread signals failure by returning None, not by raising;
        # without this check the comparison below raises a cryptic TypeError.
        raise ValueError(f'Could not read mask: {mask_path}')
    mask = mask > 0
    return mask
def tensorize(array, model_dtype, device):
    """Convert a numpy array to a torch tensor on *device*.

    The tensor is cast to *model_dtype* and given a leading batch
    dimension of size 1.
    """
    tensor = torch.from_numpy(array)
    tensor = tensor.to(device)
    tensor = tensor.to(model_dtype)
    return tensor.unsqueeze(dim=0)
def load_infer_data(config, device):
    """Load inference images (and optional masks) described by ``config``.

    Reads every ``*.png`` under ``config['inference']['images']`` in sorted
    order, converts each to RGB, and prepares three parallel representations:
    PIL images, raw cv2/numpy arrays, and batched torch tensors.  For each
    image a boolean mask is loaded from the matching path under
    ``config['inference']['masks']``; when no mask file exists, an all-True
    (H, W) mask is used.

    Args:
        config: Nested dict; reads ``inference.images``, ``inference.masks``,
            ``spherevit.dtype``, ``env.verbose`` and ``env.logger``.
        device: Target torch device for the tensorized images.

    Returns:
        dict with keys ``images`` (sub-dict ``'PIL'``/``'cv2'``/``'torch'``),
        ``masks``, ``filenames`` (basenames without extension) and ``size``.
    """
    image_dir = config['inference']['images']
    mask_dir = config['inference']['masks']
    image_paths = sorted(glob(os.path.join(image_dir, '*.png')))
    # splitext instead of slicing off a hard-coded 4-character extension.
    filenames = [os.path.splitext(os.path.basename(image_path))[0]
                 for image_path in image_paths]
    cv2_images = [read_cv2_image(image_path)
                  for image_path in image_paths]
    PIL_images = [Image.fromarray(cv2_image) for cv2_image in cv2_images]
    images = [torch_transform(cv2_image) for cv2_image in cv2_images]
    # Mask paths mirror the image paths with the directory swapped.
    # NOTE(review): str.replace would also rewrite image_dir if that string
    # occurred elsewhere in the path — assumed not to happen for this layout.
    mask_paths = [image_path.replace(image_dir, mask_dir)
                  for image_path in image_paths]
    masks = [read_mask(mask_path, images[i].shape)
             for (i, mask_path) in enumerate(mask_paths)]
    model_dtype = config['spherevit']['dtype']
    images = [tensorize(image, model_dtype, device) for image in images]
    infer_data = {
        'images': {
            'PIL': PIL_images,
            'cv2': cv2_images,
            'torch': images
        },
        'masks': masks,
        'filenames': filenames,
        'size': len(images)
    }
    if config['env']['verbose']:
        s = 's' if len(images) > 1 else ''
        config['env']['logger'].info(f'Loaded {len(images)} image{s} in {model_dtype}')
    return infer_data