Last active
May 26, 2023 17:07
-
-
Save cip8/64a7503b079fbf4dc6ab5f807702838f to your computer and use it in GitHub Desktop.
Custom mask to SAM
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_grid_points(self, mask_arr: np.ndarray, pad_ratio: int = 50) -> np.ndarray:
    """
    Returns the points of a regular grid that lie on the foreground of a mask.

    The spacing between grid points scales with the size of the mask:
    `sqrt(mask_arr.size) / pad_ratio`, truncated to an integer and never
    smaller than one pixel.

    Parameters:
    - mask_arr: A 2D numpy array. Every non-zero element is treated as foreground.
    - pad_ratio: Scaling factor (divisor) used to compute the spacing between
      points in the grid. Larger values result in a denser grid. Must be positive.

    Returns:
    - A 2D numpy array where each row is a foreground grid point (x, y).
      Points are ordered row by row (y outer, x inner). The array is empty
      with shape (0, 2) when no grid point falls on the foreground.

    Raises:
    - ValueError: If `pad_ratio` is not positive.
    """
    if pad_ratio <= 0:
        raise ValueError(f"pad_ratio must be positive, got {pad_ratio}")

    # Any non-zero value counts as foreground.
    foreground = mask_arr != 0

    # Grid spacing grows with the mask size. Clamp to 1 because small masks
    # (fewer than pad_ratio**2 elements) would otherwise truncate to a step
    # of 0, which np.mgrid cannot handle.
    step = max(1, int(np.sqrt(foreground.size) / pad_ratio))

    # Create a grid of (row, column) coordinates spaced out by the step.
    grid_y, grid_x = np.mgrid[
        0 : foreground.shape[0] : step, 0 : foreground.shape[1] : step
    ]

    # Flatten the grid arrays and stack them into an (N, 2) array of (x, y) points.
    points = np.vstack((grid_x.ravel(), grid_y.ravel())).T

    # Keep only the points that land on the foreground (mask is indexed [y, x]).
    return points[foreground[points[:, 1], points[:, 0]]]
def resize_mask(
    self, ref_mask: np.ndarray, longest_side: int = 256
) -> tuple[np.ndarray, int, int]:
    """
    Resize an image so that its longest side equals the specified value.

    The aspect ratio is preserved (the shorter side is scaled by the same
    factor and truncated to an integer). Nearest-neighbour interpolation is
    used, so no new pixel values are introduced into the mask.

    Args:
        ref_mask (np.ndarray): The image to be resized. Its first two
            dimensions are read as (height, width).
        longest_side (int, optional): The length of the longest side after
            resizing. Default is 256.

    Returns:
        tuple[np.ndarray, int, int]: The resized image and its new height and width.
    """
    height, width = ref_mask.shape[:2]

    # The longer side is pinned to `longest_side`; the other side is scaled
    # proportionally. Clamp the scaled side to at least 1 pixel: for very
    # elongated inputs the truncation would otherwise yield 0, which
    # cv2.resize rejects as a target size.
    if height > width:
        new_height = longest_side
        new_width = max(1, int(width * (new_height / height)))
    else:
        new_width = longest_side
        new_height = max(1, int(height * (new_width / width)))

    # Note that cv2.resize takes the target size as (width, height).
    resized = cv2.resize(
        ref_mask, (new_width, new_height), interpolation=cv2.INTER_NEAREST
    )
    return resized, new_height, new_width
def pad_mask(
    self,
    ref_mask: np.ndarray,
    new_height: int,
    new_width: int,
    pad_all_sides: bool = False,
    target_size: int = 256,
) -> np.ndarray:
    """
    Add zero padding to a 2D image to make it a `target_size` x `target_size` square.

    Args:
        ref_mask (np.ndarray): The 2D image to be padded.
        new_height (int): The height of the image after resizing.
        new_width (int): The width of the image after resizing.
        pad_all_sides (bool, optional): Whether to pad all sides of the image equally. If False, padding will be added to the bottom and right sides. Default is False.
        target_size (int, optional): Side length of the padded square. Default is 256.

    Returns:
        np.ndarray: The padded image.

    Raises:
        ValueError: If `new_height` or `new_width` is larger than `target_size`.
    """
    pad_height = target_size - new_height
    pad_width = target_size - new_width
    if pad_height < 0 or pad_width < 0:
        raise ValueError(
            f"Mask of size {new_height}x{new_width} does not fit into a "
            f"{target_size}x{target_size} square."
        )

    if pad_all_sides:
        # Center the image; when the padding is odd, the extra pixel goes to
        # the bottom / right side.
        top = pad_height // 2
        left = pad_width // 2
        padding = ((top, pad_height - top), (left, pad_width - left))
    else:
        # Keep the image anchored at the top-left corner.
        padding = ((0, pad_height), (0, pad_width))

    # Padding value defaults to '0' when the `np.pad` mode is set to 'constant'.
    return np.pad(ref_mask, padding, mode="constant")
def reference_to_sam_mask(
    self, ref_mask: np.ndarray, threshold: int = 127, pad_all_sides: bool = False
) -> np.ndarray:
    """
    Turn a grayscale mask into a square, signed binary mask.

    The mask is binarized to +1 / -1, resized so its longest side is 256
    (the `resize_mask` default) and then padded into a square by `pad_mask`.

    Args:
        ref_mask (np.ndarray): The grayscale mask to be processed.
        threshold (int, optional): Binarization threshold. Pixels strictly above it become 1, all others become -1. Default is 127.
        pad_all_sides (bool, optional): Whether to pad all sides of the image equally. If False, padding will be added to the bottom and right sides. Default is False.

    Returns:
        np.ndarray: The processed binary mask.
    """
    # Binarize: strictly above the threshold -> 1, everything else -> -1.
    signed_mask = np.where(ref_mask > threshold, 1, -1)

    # Scale so that the longest side is 256, keeping the aspect ratio.
    scaled, scaled_height, scaled_width = self.resize_mask(signed_mask)

    # Pad the scaled mask to make it square.
    return self.pad_mask(scaled, scaled_height, scaled_width, pad_all_sides)
[...] | |
# Obtain SAM compatible mask. | |
sam_mask: np.ndarray = self.reference_to_sam_mask(ref_mask) | |
# Initialize SAM predictor and set the image. | |
predictor: SamPredictor = SamPredictor(self._models["sam"]) | |
predictor.set_image(img_arr) # bbox cut image! | |
# Expand SAM mask's dimensions to 1xHxW (1x256x256). | |
sam_mask = np.expand_dims(sam_mask, axis=0) | |
# Run SAM predictor. | |
masks, scores, logits = predictor.predict( | |
multimask_output=True, | |
point_coords=input_points, | |
point_labels=np.ones(len(input_points)), | |
mask_input=sam_mask, | |
) | |
[...] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment