DA-2 / da2 /utils /io.py
haodongli's picture
update
d82e7f9
import os
import cv2
import torch
import numpy as np
from glob import glob
from PIL import Image
def torch_transform(image):
    """Scale pixel values by 1/255 and move the last axis to the front.

    Args:
        image: Array with exactly three axes. Presumably an ``(H, W, C)``
            image holding values in ``[0, 255]``; neither the layout nor the
            value range is checked here.

    Returns:
        The scaled data with its axes permuted to ``(2, 0, 1)`` order.
    """
    channel_first_axes = (2, 0, 1)
    scaled = image / 255.0
    return np.transpose(scaled, channel_first_axes)
def read_cv2_image(image_path):
    """Load an image from disk with OpenCV and return it in RGB channel order.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The decoded image, converted from OpenCV's BGR order to RGB.

    Raises:
        OSError: If OpenCV cannot read or decode the file.
    """
    bgr_image = cv2.imread(image_path)
    # cv2.imread reports failure by returning None instead of raising; without
    # this check the problem surfaces as an opaque assertion inside cvtColor.
    if bgr_image is None:
        raise OSError(f'Failed to read image: {image_path}')
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
def read_mask(mask_path, shape):
    """Load a boolean mask from disk, or build an all-True one if it is absent.

    Args:
        mask_path: Path of the mask image. It is read as grayscale and every
            non-zero pixel becomes True.
        shape: Shape of the matching image. Its first entry is dropped and the
            remainder is used as the shape of the fallback mask (presumably
            the image is channel-first, so this leaves ``(H, W)``).

    Returns:
        A boolean numpy array.

    Raises:
        OSError: If the mask file exists but OpenCV cannot decode it.
    """
    if not os.path.exists(mask_path):
        # No mask file on disk: fall back to a mask that is True everywhere.
        return np.ones(shape[1:], dtype=bool)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None for an unreadable file; comparing None with 0
    # would raise an unrelated TypeError, so report the real cause instead.
    if mask is None:
        raise OSError(f'Failed to read mask: {mask_path}')
    return mask > 0
def tensorize(array, model_dtype, device):
    """Convert a numpy array into a tensor with a leading axis of size one.

    Args:
        array: Numpy array to convert.
        model_dtype: Target dtype handed to ``Tensor.to``.
        device: Target device handed to ``Tensor.to``.

    Returns:
        The tensor moved to ``device``, cast to ``model_dtype`` and with a new
        axis inserted at position 0.
    """
    tensor = torch.from_numpy(array)
    # Move to the device first and cast second, in that order.
    on_device = tensor.to(device)
    typed = on_device.to(model_dtype)
    return typed.unsqueeze(dim=0)
def load_infer_data(config, device):
    """Load every ``*.png`` in the configured image directory for inference.

    Reads ``config['inference']['images']`` (image directory),
    ``config['inference']['masks']`` (mask directory),
    ``config['spherevit']['dtype']`` (tensor dtype) and, for logging,
    ``config['env']['verbose']`` and ``config['env']['logger']``.

    Args:
        config: Nested mapping holding the keys listed above.
        device: Device the image tensors are moved to.

    Returns:
        A dict with:
            ``'images'``: dict of parallel lists keyed ``'PIL'`` (PIL images),
                ``'cv2'`` (RGB arrays as read) and ``'torch'`` (scaled,
                axis-permuted tensors with a leading axis of size one);
            ``'masks'``: list of boolean masks, one per image;
            ``'filenames'``: list of image base names without extension;
            ``'size'``: number of images loaded.
    """
    image_dir = config['inference']['images']
    mask_dir = config['inference']['masks']

    # Sorted so that the per-image lists have a deterministic order.
    image_paths = sorted(glob(os.path.join(image_dir, '*.png')))
    basenames = [os.path.basename(image_path) for image_path in image_paths]
    filenames = [os.path.splitext(basename)[0] for basename in basenames]

    cv2_images = [read_cv2_image(image_path) for image_path in image_paths]
    PIL_images = [Image.fromarray(cv2_image) for cv2_image in cv2_images]
    images = [torch_transform(cv2_image) for cv2_image in cv2_images]

    # Build each mask path from the file name rather than by string-replacing
    # the directory: str.replace rewrites every occurrence, so a file such as
    # 'demo/demo_1.png' would otherwise map to 'masks/masks_1.png'.
    mask_paths = [os.path.join(mask_dir, basename) for basename in basenames]
    masks = [read_mask(mask_path, image.shape)
             for mask_path, image in zip(mask_paths, images)]

    model_dtype = config['spherevit']['dtype']
    images = [tensorize(image, model_dtype, device) for image in images]

    infer_data = {
        'images': {
            'PIL': PIL_images,
            'cv2': cv2_images,
            'torch': images,
        },
        'masks': masks,
        'filenames': filenames,
        'size': len(images),
    }

    if config['env']['verbose']:
        # Pluralise for every count except exactly one (including zero).
        s = '' if len(images) == 1 else 's'
        config['env']['logger'].info(f'Loaded {len(images)} image{s} in {model_dtype}')
    return infer_data