File size: 1,968 Bytes
d82e7f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import cv2
import torch
import numpy as np
from glob import glob
from PIL import Image


def torch_transform(image):
    """Divide an array by 255 and move its last axis to the front.

    Args:
        image: Three-dimensional array; the caller in this file passes an
            RGB image as read by OpenCV (presumably height x width x
            channel, uint8 -- confirm against other callers).

    Returns:
        The scaled array with axes reordered as (2, 0, 1).
    """
    return np.transpose(image / 255.0, axes=(2, 0, 1))

def read_cv2_image(image_path):
    """Read an image file with OpenCV and return it in RGB channel order.

    Args:
        image_path: Path of the image file to load.

    Returns:
        The decoded image converted from OpenCV's BGR order to RGB.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an opaque assertion inside cvtColor.
    if not os.path.exists(image_path):
        raise FileNotFoundError(f'Image file not found: {image_path}')
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f'Could not decode image file: {image_path}')
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

def read_mask(mask_path, shape):
    """Load a boolean mask, or build an all-True one if the file is absent.

    Args:
        mask_path: Path of the mask image; it is read as grayscale.
        shape: Shape used to size the fallback mask; its first entry is
            dropped (the caller in this file passes the shape of a
            transposed image, presumably channel-first -- confirm).

    Returns:
        A boolean array: all True with shape ``shape[1:]`` when
        ``mask_path`` does not exist, otherwise True wherever the loaded
        mask pixel is greater than zero.

    Raises:
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    if not os.path.exists(mask_path):
        # No mask on disk: fall back to a mask that is True everywhere.
        return np.ones(shape[1:], dtype=bool)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread returns None on failure; without this check the
        # comparison below would raise an uninformative TypeError.
        raise ValueError(f'Could not decode mask file: {mask_path}')
    return mask > 0

def tensorize(array, model_dtype, device):
    """Turn a NumPy array into a tensor with a leading dimension of size 1.

    Args:
        array: NumPy array to convert.
        model_dtype: Target passed to ``Tensor.to`` after the device move
            (presumably a ``torch.dtype`` -- confirm against the config).
        device: Device the tensor is moved to first.

    Returns:
        The converted tensor with a new axis inserted at position 0.
    """
    tensor = torch.from_numpy(array)
    # Keep the original order: device transfer first, then dtype cast.
    tensor = tensor.to(device)
    tensor = tensor.to(model_dtype)
    return tensor.unsqueeze(dim=0)

def load_infer_data(config, device):
    """Load every PNG image (and its optional mask) configured for inference.

    Args:
        config: Nested mapping. Keys read here:
            ``config['inference']['images']`` -- directory searched for
            ``*.png`` files;
            ``config['inference']['masks']`` -- directory searched for
            masks with the same file names;
            ``config['spherevit']['dtype']`` -- dtype handed to
            ``tensorize``;
            ``config['env']['verbose']`` and ``config['env']['logger']`` --
            whether and where to log a summary line.
        device: Device the image tensors are moved to.

    Returns:
        A dict with:
            ``'images'``: dict of parallel lists keyed ``'PIL'``,
            ``'cv2'`` and ``'torch'`` holding each image in that form;
            ``'masks'``: list of boolean masks, one per image;
            ``'filenames'``: list of image base names without extension;
            ``'size'``: number of images loaded.
        All lists follow the sorted order of the image paths.
    """
    image_dir = config['inference']['images']
    mask_dir = config['inference']['masks']

    # Sort so that the ordering is deterministic regardless of the filesystem.
    image_paths = sorted(glob(os.path.join(image_dir, '*.png')))
    filenames = [
        os.path.splitext(os.path.basename(image_path))[0]
        for image_path in image_paths
    ]

    cv2_images = [read_cv2_image(image_path) for image_path in image_paths]
    PIL_images = [Image.fromarray(cv2_image) for cv2_image in cv2_images]
    images = [torch_transform(cv2_image) for cv2_image in cv2_images]

    # A mask shares its image's file name but lives in mask_dir. Build the
    # path from the base name: str.replace(image_dir, mask_dir) would rewrite
    # every occurrence of the directory string (including inside the file
    # name) and drops the separator when only one directory has a trailing
    # slash.
    mask_paths = [
        os.path.join(mask_dir, os.path.basename(image_path))
        for image_path in image_paths
    ]
    masks = [
        read_mask(mask_path, image.shape)
        for mask_path, image in zip(mask_paths, images)
    ]

    model_dtype = config['spherevit']['dtype']
    images = [tensorize(image, model_dtype, device) for image in images]

    infer_data = {
        'images': {
            'PIL': PIL_images,
            'cv2': cv2_images,
            'torch': images,
        },
        'masks': masks,
        'filenames': filenames,
        'size': len(images),
    }
    if config['env']['verbose']:
        s = 's' if len(images) > 1 else ''
        config['env']['logger'].info(f'Loaded {len(images)} image{s} in {model_dtype}')
    return infer_data