# GitHub repo for our work
[FreeMask](https://github.com/apophis30/FreeMask)
# Intro
FreeSOLO is a novel approach to self-supervised instance segmentation that requires no annotations of any kind, neither pixel-level masks nor image-level labels.
The pipeline of FreeSOLO consists of two main components:
1. Free Mask - For each unlabeled image, the Free Mask approach generates coarse object masks using a ResNet-50-based backbone.
2. Self-Supervised SOLO - This component trains a SOLO-based instance segmenter on the coarse masks and semantic embeddings produced by Free Mask, combining a weakly supervised design, self-training, and semantic embedding learning.
## Mask Generation
### SOLOv2MaskHead
The SOLOv2MaskHead class represents the mask head component of the SOLOv2 model. It is responsible for generating instance masks given the input feature maps for SOLOv2 training.
```python
class SOLOv2MaskHead(nn.Module):
    def __init__(self, cfg, input_shape: List[ShapeSpec]):
        """
        SOLOv2 Mask Head.

        Args:
            cfg (CfgNode): The configuration node.
            input_shape (List[ShapeSpec]): List of input feature shapes.
        """
        super().__init__()
        # Extract relevant configuration parameters
        self.mask_on = cfg.MODEL.MASK_ON
        self.num_masks = cfg.MODEL.SOLOV2.NUM_MASKS
        self.mask_in_features = cfg.MODEL.SOLOV2.MASK_IN_FEATURES
        self.mask_in_channels = cfg.MODEL.SOLOV2.MASK_IN_CHANNELS
        self.mask_channels = cfg.MODEL.SOLOV2.MASK_CHANNELS
        self.num_levels = len(input_shape)
        norm = cfg.MODEL.SOLOV2.NORM  # normalization type; None disables the conv bias below

        # Initialize convolutional layers for each feature level
        self.convs_all_levels = nn.ModuleList()
        # ...

        # Initialize convolutional layers for final mask prediction
        self.conv_pred = nn.Sequential(
            nn.Conv2d(
                self.mask_channels, self.num_masks,
                kernel_size=1, stride=1, padding=0, bias=norm is None
            ),
            nn.GroupNorm(32, self.num_masks),
            nn.ReLU(inplace=True)
        )
        # ...

    def forward(self, features):
        """
        Perform forward pass through the SOLOv2 Mask Head.

        Args:
            features (list[Tensor]): FPN feature map tensors, ordered from
                high to low resolution; each tensor corresponds to one
                feature level.

        Returns:
            mask_pred (Tensor): Predicted masks or segmentation maps.
        """
        # ...
```
### Mask Generation Pipeline
#### Input
Unlabeled image I of size H × W.
#### Feature Extraction
Dense feature maps I ∈ R^(H×W×E) are extracted using a pre-trained backbone model (e.g., ResNet or other convolutional neural networks).
```python
# Feature Pyramid Network
class Model(nn.Module):
    class ForwardInput(object):
        class Train(NamedTuple):
            image: Tensor
        class Eval(NamedTuple):
            image: Tensor

    class ForwardOutput(object):
        class Train(NamedTuple):
            output: Tensor
        class Eval(NamedTuple):
            output: Tensor

    def __init__(self, backbone, num_classes: int):
        super().__init__()
        # Initialize the backbone layers
        ...

    def forward(self, forward_input: Union[ForwardInput.Train, ForwardInput.Eval]) -> Union[ForwardOutput.Train, ForwardOutput.Eval]:
        # Forward pass implementation
        ...

    def output_shape(self) -> Dict[str, ShapeSpec]:
        # Define output shape for each feature map
        ...

# Create an instance of the model with a pre-trained ResNet-50 backbone
backbone = Model(backbone=ResNet50(pretrained=True), num_classes=None)
# Wrap the image tensor in a ForwardInput object
forward_input = Model.ForwardInput.Eval(image=image)
# Run the forward pass to obtain the feature maps
features = backbone.forward(forward_input)
```
#### Queries and Keys
Queries Q and keys K are constructed from the extracted feature maps I. The queries Q ∈ R^(H'×W'×E) are obtained by bilinearly downsampling I to a smaller spatial size, where H' and W' are the downsampled dimensions; the keys K ∈ R^(H×W×E) are the original feature maps I.
```python
# Use the coarsest FPN level as keys; build queries at several scales
mask_features = features['p5']
keys = mask_features.squeeze(0)  # [E, H, W]
scale_factors = [1.0, 0.5, 0.25]
queries_list = []
for scale_factor in scale_factors:
    # Downsample the keys bilinearly, then flatten into per-pixel queries
    cur_queries = F.interpolate(
        keys[None, ...], scale_factor=scale_factor, mode='bilinear'
    )[0].reshape(keys.shape[0], -1).permute(1, 0)  # [H'*W', E]
    num_q = len(cur_queries)
    queries_list.append(cur_queries)
queries = torch.cat(queries_list)
_, H, W = keys.shape
# L2-normalize keys per pixel and queries per embedding
keys = keys / keys.norm(dim=0, keepdim=True)
queries = queries / queries.norm(dim=1, keepdim=True)
```
#### Cosine Similarity
For each query in Q, we compute its cosine similarity with every key in K, obtaining the score maps S = Q′ ⊛ K′, where Q′ and K′ denote the L2-normalized queries and keys.
Soft Masks: each score map is normalized to the range [0, 1], yielding a soft mask.
```python
# Attention map: dot product between normalized queries and flattened keys,
# giving one [H*W] similarity row per query
attn = queries @ keys.reshape(keys.shape[0], -1)
# Min-max normalize each row so its values lie in [0, 1]
attn -= attn.min(-1, keepdim=True)[0]
attn /= attn.max(-1, keepdim=True)[0]
# Reshape back to [num_queries, H, W]
attn = attn.reshape(attn.shape[0], H, W)
# Soft masks, and binary masks via thresholding at 0.5
soft_masks = attn
masks = soft_masks >= 0.5
# Downsample the query embeddings to a fixed size of 128
queries = F.interpolate(queries[None, ...], size=128, mode='linear')[0]
# Number of "on" pixels per binary mask
sum_masks = masks.sum((1, 2))
# Keep only masks with more than one foreground pixel;
# `continue` skips to the next image of the enclosing loop
keep = sum_masks > 1
if keep.sum() == 0:
    continue
masks = masks[keep]
soft_masks = soft_masks[keep]
sum_masks = sum_masks[keep]
queries = queries[keep]
```
#### Maskness Score
A maskness score is computed for each soft mask to evaluate its quality or confidence. It is computed non-parametrically as the average soft-mask value over the mask's foreground pixels, i.e., the soft scores summed inside the binary mask and divided by the mask's pixel count.
```python
# Maskness: mean soft-mask score over each mask's foreground pixels
maskness = (soft_masks * masks.float()).sum((1, 2)) / sum_masks
sort_inds = torch.argsort(maskness, descending=True)
maskness = maskness[sort_inds]
masks = masks[sort_inds]
sum_masks = sum_masks[sort_inds]
soft_masks = soft_masks[sort_inds]
queries = queries[sort_inds]

# Matrix NMS: decay the maskness scores of overlapping masks.
# The first argument is the per-mask category labels; maskness * 0 puts
# every mask in a single class 0, so all pairs can suppress each other.
maskness = matrix_nms(maskness * 0, masks, sum_masks, maskness,
                      sigma=2, kernel='gaussian')
```
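The `matrix_nms` function is not defined in this snippet; below is a self-contained sketch following the SOLOv2 Matrix NMS formulation with the Gaussian decay kernel (the signature mirrors the call above, with per-mask category labels as the first argument). Masks are expected to be sorted by descending score.

```python
import torch

def matrix_nms(cate_labels, seg_masks, sum_masks, cate_scores,
               sigma=2.0, kernel='gaussian'):
    """Matrix NMS (SOLOv2-style sketch): decay each mask's score by its
    IoU with higher-scoring masks of the same category, in parallel."""
    n = len(cate_labels)
    seg_masks = seg_masks.reshape(n, -1).float()
    # Pairwise intersection and IoU; triu(1) keeps only (higher, lower)-score pairs
    inter = torch.mm(seg_masks, seg_masks.t())
    sums = sum_masks.expand(n, n)
    iou = (inter / (sums + sums.t() - inter)).triu(diagonal=1)
    # Only masks of the same category suppress each other
    labels = cate_labels.expand(n, n)
    same_label = (labels == labels.t()).float().triu(diagonal=1)
    decay_iou = iou * same_label
    # IoU compensation: how much each suppressor was itself overlapped
    # by better masks (so already-decayed masks suppress less)
    compensate_iou, _ = decay_iou.max(0)
    compensate_iou = compensate_iou.expand(n, n).t()
    if kernel == 'gaussian':
        decay = torch.exp(-sigma * decay_iou ** 2) / torch.exp(-sigma * compensate_iou ** 2)
    else:  # linear kernel
        decay = (1 - decay_iou) / (1 - compensate_iou)
    # Apply the strongest decay each mask receives
    decay_coeff, _ = decay.min(0)
    return cate_scores * decay_coeff
```

With all labels set to zero, as in the call above, every pair of masks can suppress each other, which matches the class-agnostic setting of Free Mask.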
#### Thresholding and NMS
The soft masks are converted into binary masks by applying a threshold τ. Masks with scores below the threshold are considered background, while those above the threshold are considered foreground. The binary masks are then sorted based on their maskness scores and redundant masks are removed using mask non-maximum suppression (NMS).
M = NMS(Maskness(Norm(S))), where S = Q′ ⊛ K′.

Matrix NMS considers the overlap between instance masks and decays each mask's score based on its IoU with higher-scoring masks of the same category; the decayed scores are then used to suppress redundant instances during post-processing.
```python
sort_inds = torch.argsort(maskness, descending=True)
# Keep only the top 20 masks (and their scores, soft masks, and queries)
if len(sort_inds) > 20:
    sort_inds = sort_inds[:20]
masks = masks[sort_inds]
maskness = maskness[sort_inds]
soft_masks = soft_masks[sort_inds]
queries = queries[sort_inds]
# Upsample the soft masks to the original image resolution (height, width)
soft_masks = F.interpolate(soft_masks[None, ...], size=(height, width), mode='bilinear')[0]
# Re-binarize at 0.5 after upsampling
masks = (soft_masks >= 0.5).float()
sum_masks = masks.sum((1, 2))
```
#### Output
The final output consists of the binary masks M generated by the Free Mask approach. Each mask is stored as a COCO-style annotation dict:
```python
cur_ann_dict = {'segmentation': rle,
                'bbox': boxes[idx],
                'score': float(maskness[idx]),
                'emb': queries[idx],
                'iscrowd': 0,
                'image_id': img_id,
                'category_id': 1,
                'id': ann_id}
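The `segmentation` (`rle`) and `bbox` fields are derived from each binary mask. Below is a minimal pure-Python sketch of COCO-style column-major RLE encoding and bounding-box extraction; `mask_to_annotation_fields` is an illustrative helper, and in practice a library such as pycocotools handles the RLE step.

```python
def mask_to_annotation_fields(binary_mask):
    """binary_mask: list of rows containing 0/1 values.
    Returns a COCO-style uncompressed RLE and an [x, y, w, h] box."""
    h, w = len(binary_mask), len(binary_mask[0])
    # COCO RLE traverses the mask in column-major (Fortran) order and
    # stores alternating run lengths, always starting with a run of 0s
    flat = [binary_mask[y][x] for x in range(w) for y in range(h)]
    counts, prev, run = [], 0, 0
    for v in flat:
        if v == prev:
            run += 1
        else:
            counts.append(run)
            prev, run = v, 1
    counts.append(run)
    rle = {'size': [h, w], 'counts': counts}
    # Tight bounding box around the foreground pixels
    xs = [x for y in range(h) for x in range(w) if binary_mask[y][x]]
    ys = [y for y in range(h) for x in range(w) if binary_mask[y][x]]
    if not xs:
        return rle, [0, 0, 0, 0]
    bbox = [min(xs), min(ys), max(xs) - min(xs) + 1, max(ys) - min(ys) + 1]
    return rle, bbox
```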
## Training
The Self-Supervised SOLO model is trained using three main strategies:
#### Learning with coarse masks
The coarse masks from Free Mask are treated as weak annotations, and the model is trained with weak supervision to align its predicted masks with them.
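One way such weak supervision can be implemented is with a projection-style loss that compares only the x- and y-axis max-projections of the predicted and coarse masks rather than the masks pixel-wise; the sketch below uses a dice loss on the 1-D projections (`projection_loss` and its exact form are illustrative assumptions, not the paper's precise objective).

```python
import torch

def projection_loss(pred_mask, coarse_mask, eps=1e-5):
    """Compare [N, H, W] mask batches only through their max-projections
    onto each image axis, so the coarse mask constrains the predicted
    mask's rough extent rather than its exact shape."""
    loss = 0.0
    for dim in (1, 2):  # project over H, then over W
        pred_proj = pred_mask.max(dim=dim)[0]     # [N, W] or [N, H]
        target_proj = coarse_mask.max(dim=dim)[0]
        # Dice loss between the two 1-D projections
        inter = (pred_proj * target_proj).sum(-1)
        union = (pred_proj ** 2).sum(-1) + (target_proj ** 2).sum(-1)
        loss = loss + (1.0 - (2.0 * inter + eps) / (union + eps)).mean()
    return loss
```

Identical masks give a loss near zero, while masks whose projections do not overlap are maximally penalized, which is what makes this usable with only coarse supervision.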
#### Self-training
After the initial weakly supervised training, self-training is performed to refine the model: unlabeled images are fed to the instance segmenter, the predicted object masks are collected, low-confidence predictions are discarded, and the remaining masks are treated as a new set of coarse masks. The model is then trained again on these new masks, iteratively improving its performance.
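The filtering step can be sketched as a simple confidence cut (the helper name and the 0.7 threshold are illustrative assumptions, not values from the paper):

```python
def select_pseudo_masks(predictions, score_thresh=0.7):
    """Keep only confident predictions; the survivors become the
    next round's coarse masks for another pass of training."""
    return [p for p in predictions if p['score'] >= score_thresh]
```

Each self-training round regenerates `predictions` with the current model and retrains on the filtered set.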
#### Semantic representation learning
In addition to instance segmentation, the model also learns semantic representations of objects.
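A hedged sketch of how this can be set up, assuming the query embeddings saved in each annotation's 'emb' field serve as targets for a per-mask embedding predicted by the segmenter (the negative-cosine form below is a simplification, not the paper's exact objective):

```python
import torch
import torch.nn.functional as F

def embedding_alignment_loss(pred_emb, target_emb):
    """Align predicted per-mask embeddings [N, E] with the Free Mask
    query embeddings by maximizing their cosine similarity."""
    pred = F.normalize(pred_emb, dim=-1)
    target = F.normalize(target_emb, dim=-1)
    # 1 - cosine similarity, averaged over masks
    return (1.0 - (pred * target).sum(-1)).mean()
```

Because both sides are L2-normalized, the loss ranges from 0 (embeddings aligned) to 2 (embeddings opposite), pushing the segmenter's representations toward the semantics captured by Free Mask.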