Spaces:
Runtime error
Runtime error
| from typing import Any, Dict, List, Union | |
| import numpy as np | |
| from PIL import Image | |
def invert_mask(mask: np.ndarray) -> np.ndarray:
    """Return the 8-bit bitwise complement of a mask.

    The input is cast to ``np.uint8`` and every value ``v`` becomes
    ``255 - v``. A 0/255 mask therefore has its black and white areas
    swapped, and intermediate grey levels are mirrored.

    Note that this is a bitwise, not a logical, inversion. A boolean or 0/1
    mask maps to 255/254, so both states stay non-zero.

    Args:
        mask (np.ndarray): mask to invert.

    Returns:
        np.ndarray: ``np.uint8`` array with the same shape as ``mask``.

    Raises:
        ValueError: if ``mask`` is ``None`` or not a ``np.ndarray``.
    """
    # ``None`` is not an ndarray, so this single test also rejects it.
    if not isinstance(mask, np.ndarray):
        raise ValueError("Invalid mask")
    as_bytes = mask.astype(np.uint8)
    return np.bitwise_not(as_bytes)
def check_inputs_create_mask_image(
    mask: Union[np.ndarray, Image.Image],
    sam_masks: List[Dict[str, Any]],
    ignore_black_chk: bool = True,
) -> None:
    """Validate the arguments of ``create_mask_image``.

    Only top-level types are checked; the entries of ``sam_masks`` are not
    inspected. The checks run in argument order, and the first failure is
    the one reported.

    Args:
        mask (Union[np.ndarray, Image.Image]): must be a NumPy array or a
            PIL image.
        sam_masks (List[Dict[str, Any]]): must be a built-in ``list``.
        ignore_black_chk (bool): must be a built-in ``bool``
            (``np.bool_`` is not accepted).

    Returns:
        None

    Raises:
        ValueError: if any argument is ``None`` or of the wrong type.
    """
    # ``None`` fails every isinstance test, so no separate None checks are needed.
    validations = (
        (isinstance(mask, (np.ndarray, Image.Image)), "Invalid mask"),
        (isinstance(sam_masks, list), "Invalid SAM masks"),
        (isinstance(ignore_black_chk, bool), "Invalid ignore black check"),
    )
    for passed, message in validations:
        if not passed:
            raise ValueError(message)
def convert_mask(mask: Union[np.ndarray, Image.Image]) -> np.ndarray:
    """Normalize a mask to an array with a single-element channel axis.

    - A PIL image is first converted with ``np.array``.
    - A 2-D ``(H, W)`` array gets a trailing axis, giving ``(H, W, 1)``.
    - An array whose third axis is longer than 1 keeps only channel 0,
      sliced with ``0:1`` so the axis is retained.

    For ndarray input no copy is made: the result is either the input itself
    or a view of it. Inputs with fewer than two dimensions are not supported
    (reading ``shape[2]`` raises ``IndexError``).

    Args:
        mask (Union[np.ndarray, Image.Image]): mask

    Returns:
        np.ndarray: mask whose third axis has length 1.
    """
    arr = np.array(mask) if isinstance(mask, Image.Image) else mask
    if arr.ndim == 2:
        # Single-plane mask: add the missing channel axis.
        return arr[:, :, np.newaxis]
    if arr.shape[2] == 1:
        return arr
    # Several channels: keep only the first; ``0:1`` keeps the axis itself.
    return arr[:, :, 0:1]
def create_mask_image(
    mask: Union[np.ndarray, Image.Image],
    sam_masks: List[Dict[str, Any]],
    ignore_black_chk: bool = True,
) -> np.ndarray:
    """Create mask image.

    Paint white the union of all SAM segments that overlap ``mask``.

    How segments are selected:
    - Segments are visited in list order. Each one only owns the pixels not
      already claimed by an earlier segment, so overlapping segments never
      share a pixel.
    - A segment is selected when any pixel it owns is non-zero in ``mask``.
    - When ``ignore_black_chk`` is ``False``, the pixels claimed by no
      segment form one extra region, which is selected by the same rule.

    Args:
        mask (Union[np.ndarray, Image.Image]): mask; any non-zero pixel
            counts as marked. It is normalized with ``convert_mask`` (shape
            ``(H, W, 1)`` for 2-D or 3-D input).
        sam_masks (List[Dict[str, Any]]): SAM masks. Each dict needs a
            ``"segmentation"`` array whose non-zero pixels belong to the
            segment. Presumably ``(H, W)`` matching the mask; it must at
            least broadcast against the normalized mask once a trailing axis
            is added.
        ignore_black_chk (bool): if ``False``, also consider the region not
            covered by any segment.

    Returns:
        np.ndarray: ``np.uint8`` image with 3 identical channels, 255 on the
        selected regions and 0 elsewhere.

    Raises:
        ValueError: if the arguments fail ``check_inputs_create_mask_image``.
        KeyError: if a SAM mask dict has no ``"segmentation"`` entry.
    """
    check_inputs_create_mask_image(mask, sam_masks, ignore_black_chk)
    mask = convert_mask(mask)

    # Work in booleans rather than uint8 products. With 0/255 segmentations,
    # uint8 arithmetic overflows: 255 * 255 wraps to 1, and 16 * 16 wraps to 0.
    marked = mask.astype(bool)
    claimed = np.zeros(mask.shape, dtype=bool)   # pixels owned by an earlier segment
    selected = np.zeros(mask.shape, dtype=bool)  # union of the selected regions

    for seg_dict in sam_masks:
        segment = np.expand_dims(seg_dict["segmentation"].astype(bool), axis=-1)
        owned = segment & ~claimed  # only pixels no earlier segment has taken
        if (owned & marked).any():
            selected |= owned
        claimed |= owned

    if not ignore_black_chk:
        # Treat everything no segment covered as one extra candidate region.
        unclaimed = ~claimed
        if (unclaimed & marked).any():
            selected |= unclaimed

    return np.tile(selected.astype(np.uint8) * 255, (1, 1, 3))