Spaces:
Runtime error
Runtime error
Commit
·
4d5065f
1
Parent(s):
562c833
upload
Browse files- util/__init__.py +0 -0
- util/batchsize.py +59 -0
- util/image_util.py +172 -0
- util/seed_all.py +13 -0
util/__init__.py
ADDED
|
File without changes
|
util/batchsize.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Search table for suggested max. inference batch size
|
| 6 |
+
bs_search_table = [
|
| 7 |
+
# tested on A100-PCIE-80GB
|
| 8 |
+
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
|
| 9 |
+
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
|
| 10 |
+
# tested on A100-PCIE-40GB
|
| 11 |
+
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
|
| 12 |
+
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
|
| 13 |
+
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
|
| 14 |
+
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
|
| 15 |
+
# tested on RTX3090, RTX4090
|
| 16 |
+
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
|
| 17 |
+
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
|
| 18 |
+
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
|
| 19 |
+
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
|
| 20 |
+
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
|
| 21 |
+
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
|
| 22 |
+
# tested on GTX1080Ti
|
| 23 |
+
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
|
| 24 |
+
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
|
| 25 |
+
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
|
| 26 |
+
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
|
| 27 |
+
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
|
| 32 |
+
"""
|
| 33 |
+
Automatically search for suitable operating batch size.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
ensemble_size (int): Number of predictions to be ensembled
|
| 37 |
+
input_res (int): Operating resolution of the input image.
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
int: Operating batch size
|
| 41 |
+
"""
|
| 42 |
+
if not torch.cuda.is_available():
|
| 43 |
+
return 1
|
| 44 |
+
|
| 45 |
+
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
|
| 46 |
+
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
|
| 47 |
+
for settings in sorted(
|
| 48 |
+
filtered_bs_search_table,
|
| 49 |
+
key=lambda k: (k["res"], -k["total_vram"]),
|
| 50 |
+
):
|
| 51 |
+
if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
|
| 52 |
+
bs = settings["bs"]
|
| 53 |
+
if bs > ensemble_size:
|
| 54 |
+
bs = ensemble_size
|
| 55 |
+
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
|
| 56 |
+
bs = math.ceil(ensemble_size / 2)
|
| 57 |
+
return bs
|
| 58 |
+
|
| 59 |
+
return 1
|
util/image_util.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
|
| 7 |
+
def norm_to_rgb(norm):
|
| 8 |
+
# norm: (3, H, W), range from [-1, 1]
|
| 9 |
+
norm_rgb = ((norm + 1) * 0.5) * 255
|
| 10 |
+
norm_rgb = np.clip(norm_rgb, a_min=0, a_max=255)
|
| 11 |
+
norm_rgb = norm_rgb.astype(np.uint8)
|
| 12 |
+
return norm_rgb
|
| 13 |
+
|
| 14 |
+
def colorize_depth_maps(
|
| 15 |
+
depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
|
| 16 |
+
):
|
| 17 |
+
"""
|
| 18 |
+
Colorize depth maps.
|
| 19 |
+
"""
|
| 20 |
+
assert len(depth_map.shape) >= 2, "Invalid dimension"
|
| 21 |
+
|
| 22 |
+
if isinstance(depth_map, torch.Tensor):
|
| 23 |
+
depth = depth_map.detach().clone().squeeze().numpy()
|
| 24 |
+
elif isinstance(depth_map, np.ndarray):
|
| 25 |
+
depth = np.squeeze(depth_map.copy())
|
| 26 |
+
# reshape to [ (B,) H, W ]
|
| 27 |
+
if depth.ndim < 3:
|
| 28 |
+
depth = depth[np.newaxis, :, :]
|
| 29 |
+
|
| 30 |
+
# colorize
|
| 31 |
+
cm = matplotlib.colormaps[cmap]
|
| 32 |
+
depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
|
| 33 |
+
img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
|
| 34 |
+
img_colored_np = np.rollaxis(img_colored_np, 3, 1)
|
| 35 |
+
|
| 36 |
+
if valid_mask is not None:
|
| 37 |
+
if isinstance(depth_map, torch.Tensor):
|
| 38 |
+
valid_mask = valid_mask.detach().numpy()
|
| 39 |
+
valid_mask = np.squeeze(valid_mask) # [H, W] or [B, H, W]
|
| 40 |
+
if valid_mask.ndim < 3:
|
| 41 |
+
valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
|
| 42 |
+
else:
|
| 43 |
+
valid_mask = valid_mask[:, np.newaxis, :, :]
|
| 44 |
+
valid_mask = np.repeat(valid_mask, 3, axis=1)
|
| 45 |
+
img_colored_np[~valid_mask] = 0
|
| 46 |
+
|
| 47 |
+
if isinstance(depth_map, torch.Tensor):
|
| 48 |
+
img_colored = torch.from_numpy(img_colored_np).float()
|
| 49 |
+
elif isinstance(depth_map, np.ndarray):
|
| 50 |
+
img_colored = img_colored_np
|
| 51 |
+
|
| 52 |
+
return img_colored
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def chw2hwc(chw):
|
| 56 |
+
assert 3 == len(chw.shape)
|
| 57 |
+
if isinstance(chw, torch.Tensor):
|
| 58 |
+
hwc = torch.permute(chw, (1, 2, 0))
|
| 59 |
+
elif isinstance(chw, np.ndarray):
|
| 60 |
+
hwc = np.moveaxis(chw, 0, -1)
|
| 61 |
+
return hwc
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
|
| 65 |
+
"""
|
| 66 |
+
Resize image to limit maximum edge length while keeping aspect ratio
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
img (Image.Image): Image to be resized
|
| 70 |
+
max_edge_resolution (int): Maximum edge length (px).
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
Image.Image: Resized image.
|
| 74 |
+
"""
|
| 75 |
+
original_width, original_height = img.size
|
| 76 |
+
downscale_factor = min(
|
| 77 |
+
max_edge_resolution / original_width, max_edge_resolution / original_height
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
new_width = int(original_width * downscale_factor)
|
| 81 |
+
new_height = int(original_height * downscale_factor)
|
| 82 |
+
|
| 83 |
+
resized_img = img.resize((new_width, new_height))
|
| 84 |
+
return resized_img
|
| 85 |
+
|
| 86 |
+
def resize_max_res_integer_16(img: Image.Image, max_edge_resolution: int) -> Image.Image:
|
| 87 |
+
"""
|
| 88 |
+
Resize image to limit maximum edge length while keeping aspect ratio
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
img (Image.Image): Image to be resized
|
| 92 |
+
max_edge_resolution (int): Maximum edge length (px).
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
Image.Image: Resized image.
|
| 96 |
+
"""
|
| 97 |
+
original_width, original_height = img.size
|
| 98 |
+
downscale_factor = min(
|
| 99 |
+
max_edge_resolution / original_width, max_edge_resolution / original_height
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
new_width = int(original_width * downscale_factor) // 16 * 16 # make sure it is integer multiples of 16, used for pixart
|
| 103 |
+
new_height = int(original_height * downscale_factor) // 16 * 16 # make sure it is integer multiples of 16, used for pixart
|
| 104 |
+
|
| 105 |
+
resized_img = img.resize((new_width, new_height))
|
| 106 |
+
return resized_img
|
| 107 |
+
|
| 108 |
+
def resize_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
|
| 109 |
+
"""
|
| 110 |
+
Resize image to limit maximum edge length while keeping aspect ratio
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
img (Image.Image): Image to be resized
|
| 114 |
+
max_edge_resolution (int): Maximum edge length (px).
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
Image.Image: Resized image.
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
resized_img = img.resize((max_edge_resolution, max_edge_resolution))
|
| 121 |
+
return resized_img
|
| 122 |
+
|
| 123 |
+
class ResizeLongestEdge:
|
| 124 |
+
def __init__(self, max_size, interpolation=transforms.InterpolationMode.BILINEAR):
|
| 125 |
+
self.max_size = max_size
|
| 126 |
+
self.interpolation = interpolation
|
| 127 |
+
|
| 128 |
+
def __call__(self, img):
|
| 129 |
+
|
| 130 |
+
scale = self.max_size / max(img.width, img.height)
|
| 131 |
+
new_size = (int(img.height * scale), int(img.width * scale))
|
| 132 |
+
|
| 133 |
+
return transforms.functional.resize(img, new_size, self.interpolation)
|
| 134 |
+
|
| 135 |
+
class ResizeShortestEdge:
|
| 136 |
+
def __init__(self, min_size, interpolation=transforms.InterpolationMode.BILINEAR):
|
| 137 |
+
self.min_size = min_size
|
| 138 |
+
self.interpolation = interpolation
|
| 139 |
+
|
| 140 |
+
def __call__(self, img):
|
| 141 |
+
|
| 142 |
+
scale = self.min_size / min(img.width, img.height)
|
| 143 |
+
new_size = (int(img.height * scale), int(img.width * scale))
|
| 144 |
+
|
| 145 |
+
return transforms.functional.resize(img, new_size, self.interpolation)
|
| 146 |
+
|
| 147 |
+
class ResizeHard:
|
| 148 |
+
def __init__(self, size, interpolation=transforms.InterpolationMode.BILINEAR):
|
| 149 |
+
self.size = size
|
| 150 |
+
self.interpolation = interpolation
|
| 151 |
+
|
| 152 |
+
def __call__(self, img):
|
| 153 |
+
|
| 154 |
+
new_size = (int(self.size), int(self.size))
|
| 155 |
+
|
| 156 |
+
return transforms.functional.resize(img, new_size, self.interpolation)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class ResizeLongestEdgeInteger:
|
| 160 |
+
def __init__(self, max_size, interpolation=transforms.InterpolationMode.BILINEAR, integer=16):
|
| 161 |
+
self.max_size = max_size
|
| 162 |
+
self.interpolation = interpolation
|
| 163 |
+
self.integer = integer
|
| 164 |
+
|
| 165 |
+
def __call__(self, img):
|
| 166 |
+
|
| 167 |
+
scale = self.max_size / max(img.width, img.height)
|
| 168 |
+
new_size_h = int(img.height * scale) // self.integer * self.integer
|
| 169 |
+
new_size_w = int(img.width * scale) // self.integer * self.integer
|
| 170 |
+
new_size = (new_size_h, new_size_w)
|
| 171 |
+
|
| 172 |
+
return transforms.functional.resize(img, new_size, self.interpolation)
|
util/seed_all.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def seed_all(seed: int = 0):
|
| 7 |
+
"""
|
| 8 |
+
Set random seeds of all components.
|
| 9 |
+
"""
|
| 10 |
+
random.seed(seed)
|
| 11 |
+
np.random.seed(seed)
|
| 12 |
+
torch.manual_seed(seed)
|
| 13 |
+
torch.cuda.manual_seed_all(seed)
|