# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Model head modules."""

import copy
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import constant_, xavier_uniform_

from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors

from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto
from .conv import Conv
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, TransformerSegmentationDecoder
from .utils import bias_init_with_prob, linear_init

__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect", "MTDETRDecoder"


class Detect(nn.Module):
    """YOLOv8 Detect head for detection models."""

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    end2end = False  # end2end
    max_det = 300  # max_det
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):
        """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

        if self.end2end:
            self.one2one_cv2 = copy.deepcopy(self.cv2)
            self.one2one_cv3 = copy.deepcopy(self.cv3)

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        if self.end2end:
            return self.forward_end2end(x)

        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return x
        y = self._inference(x)
        return y if self.export else (y, x)

    def forward_end2end(self, x):
        """
        Performs forward pass of the end-to-end (v10-style) detection head.

        Args:
            x (List[torch.Tensor]): Input feature maps, one per detection layer.

        Returns:
            (dict | tuple): In training mode, a dictionary with the "one2many" and "one2one" branch outputs.
                In inference mode, the post-processed one2one detections, paired with the raw branch outputs
                unless exporting.
        """
        x_detach = [xi.detach() for xi in x]
        one2one = [
            torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
        ]
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return {"one2many": x, "one2one": one2one}

        y = self._inference(one2one)
        y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
        return y if self.export else (y, {"one2many": x, "one2one": one2one})

    def _inference(self, x):
        """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
        # Inference path
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4 :]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

        if self.export and self.format in {"tflite", "edgetpu"}:
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides

        return torch.cat((dbox, cls.sigmoid()), 1)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
        if self.end2end:
            for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride):  # from
                a[-1].bias.data[:] = 1.0  # box
                b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

    def decode_bboxes(self, bboxes, anchors):
        """Decode bounding boxes."""
        return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)

    @staticmethod
    def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
        """
        Post-processes YOLO model predictions.

        Args:
            preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
                format [x, y, w, h, class_probs].
            max_det (int): Maximum detections per image.
            nc (int, optional): Number of classes. Default: 80.

        Returns:
            (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
                dimension format [x, y, w, h, max_class_prob, class_index].
        """
        batch_size, anchors, _ = preds.shape  # i.e. shape(16,8400,84)
        boxes, scores = preds.split([4, nc], dim=-1)
        index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
        boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
        scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
        scores, index = scores.flatten(1).topk(min(max_det, anchors))
        i = torch.arange(batch_size)[..., None]  # batch indices
        return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
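

# Example (illustrative sketch, not part of the model): Detect.postprocess is a static
# method, so the NMS-free top-k logic can be exercised on a dummy tensor. Shapes follow
# the docstring above; the values here are random.
#
#   preds = torch.rand(2, 8400, 84)                 # (batch, anchors, 4 box + 80 cls)
#   out = Detect.postprocess(preds, max_det=300, nc=80)
#   print(out.shape)                                # torch.Size([2, 300, 6])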


class Segment(Detect):
    """YOLOv8 Segment head for segmentation models."""

    def __init__(self, nc=80, nm=32, npr=256, ch=()):
        """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers."""
        super().__init__(nc, ch)
        self.nm = nm  # number of masks
        self.npr = npr  # number of protos
        self.proto = Proto(ch[0], self.npr, self.nm)  # protos

        c4 = max(ch[0] // 4, self.nm)
        self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)

    def forward(self, x):
        """Return detections, mask coefficients, and prototypes if training; otherwise return processed outputs."""
        p = self.proto(x[0])  # mask protos
        bs = p.shape[0]  # batch size

        mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)  # mask coefficients
        x = Detect.forward(self, x)
        if self.training:
            return x, mc, p
        return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
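

# Example (illustrative sketch): downstream, per-image instance masks are assembled from
# the coefficients `mc` and prototypes `p`, roughly as in ultralytics.utils.ops.process_mask,
# i.e. a linear combination of prototypes followed by a sigmoid. The shapes below are
# assumptions for a 640x640 input with nm=32 prototypes and 10 kept detections.
#
#   coeffs = torch.rand(10, 32)                     # 32 coefficients per detection
#   protos = torch.rand(32, 160, 160)               # prototype masks (nm, mh, mw)
#   masks = (coeffs @ protos.view(32, -1)).sigmoid().view(-1, 160, 160)  # (10, 160, 160)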


class OBB(Detect):
    """YOLOv8 OBB detection head for detection with rotation models."""

    def __init__(self, nc=80, ne=1, ch=()):
        """Initialize OBB with number of classes `nc` and layer channels `ch`."""
        super().__init__(nc, ch)
        self.ne = ne  # number of extra parameters

        c4 = max(ch[0] // 4, self.ne)
        self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)

    def forward(self, x):
        """Concatenates and returns predicted rotated bounding boxes and class probabilities."""
        bs = x[0].shape[0]  # batch size
        angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2)  # OBB theta logits
        # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.
        angle = (angle.sigmoid() - 0.25) * math.pi  # [-pi/4, 3pi/4]
        # angle = angle.sigmoid() * math.pi / 2  # [0, pi/2]
        if not self.training:
            self.angle = angle
        x = Detect.forward(self, x)
        if self.training:
            return x, angle
        return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))

    def decode_bboxes(self, bboxes, anchors):
        """Decode rotated bounding boxes."""
        return dist2rbox(bboxes, self.angle, anchors, dim=1)
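

# Example (illustrative): the sigmoid mapping above confines the predicted angle to a
# single contiguous [-pi/4, 3pi/4) range, avoiding the wrap-around ambiguity at 0/pi:
#
#   logits = torch.tensor([-4.0, 0.0, 4.0])
#   angles = (logits.sigmoid() - 0.25) * math.pi    # ≈ tensor([-0.73, 0.79, 2.30]) rad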


class Pose(Detect):
    """YOLOv8 Pose head for keypoints models."""

    def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
        """Initialize YOLO network with default parameters and Convolutional Layers."""
        super().__init__(nc, ch)
        self.kpt_shape = kpt_shape  # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
        self.nk = kpt_shape[0] * kpt_shape[1]  # number of keypoints total

        c4 = max(ch[0] // 4, self.nk)
        self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)

    def forward(self, x):
        """Perform forward pass through YOLO model and return predictions."""
        bs = x[0].shape[0]  # batch size
        kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1)  # (bs, 17*3, h*w)
        x = Detect.forward(self, x)
        if self.training:
            return x, kpt
        pred_kpt = self.kpts_decode(bs, kpt)
        return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))

    def kpts_decode(self, bs, kpts):
        """Decodes keypoints."""
        ndim = self.kpt_shape[1]
        if self.export:  # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
            y = kpts.view(bs, *self.kpt_shape, -1)
            a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
            if ndim == 3:
                a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
            return a.view(bs, self.nk, -1)
        else:
            y = kpts.clone()
            if ndim == 3:
                y[:, 2::3] = y[:, 2::3].sigmoid()  # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug)
            y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
            y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
            return y
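

# Example (illustrative): kpts_decode maps a raw keypoint coordinate to pixel space with
# x_px = (raw * 2.0 + (anchor - 0.5)) * stride. With assumed values:
#
#   raw, anchor_x, stride = 0.75, 10.5, 8.0
#   x_px = (raw * 2.0 + (anchor_x - 0.5)) * stride  # -> 92.0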


class Classify(nn.Module):
    """YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
        """Initializes YOLOv8 classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
        super().__init__()
        c_ = 1280  # efficientnet_b0 size
        self.conv = Conv(c1, c_, k, s, p, g)
        self.pool = nn.AdaptiveAvgPool2d(1)  # to x(b,c_,1,1)
        self.drop = nn.Dropout(p=0.0, inplace=True)
        self.linear = nn.Linear(c_, c2)  # to x(b,c2)

    def forward(self, x):
        """Performs a forward pass of the YOLO model on input image data."""
        if isinstance(x, list):
            x = torch.cat(x, 1)
        x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
        return x if self.training else x.softmax(1)
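

# Example (illustrative): the head returns raw logits in training mode and softmax
# probabilities in eval mode. The channel sizes below are assumptions.
#
#   m = Classify(c1=256, c2=1000)
#   logits = m(torch.rand(2, 256, 20, 20))          # train mode: logits, shape (2, 1000)
#   probs = m.eval()(torch.rand(2, 256, 20, 20))    # eval mode: softmax over classes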


class WorldDetect(Detect):
    """Head for integrating YOLOv8 detection models with semantic understanding from text embeddings."""

    def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
        """Initialize YOLOv8 detection layer with nc classes and layer channels ch."""
        super().__init__(nc, ch)
        c3 = max(ch[0], min(self.nc, 100))
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
        self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)

    def forward(self, x, text):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)
        if self.training:
            return x

        # Inference path
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.nc + self.reg_max * 4, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4 :]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

        if self.export and self.format in {"tflite", "edgetpu"}:
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides

        y = torch.cat((dbox, cls.sigmoid()), 1)
        return y if self.export else (y, x)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
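

# Example (illustrative sketch): in training mode the head consumes per-level feature maps
# plus one text embedding per class; the contrastive heads score features against the text.
# All sizes below are assumptions.
#
#   head = WorldDetect(nc=80, embed=512, ch=(256, 512, 512))
#   feats = [torch.rand(1, c, s, s) for c, s in zip((256, 512, 512), (80, 40, 20))]
#   text = torch.rand(1, 80, 512)                   # (batch, classes, embed)
#   out = head(feats, text)                         # list of per-level prediction maps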


class RTDETRDecoder(nn.Module):
    """
    Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.

    This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes
    and class labels for objects in an image. It integrates features from multiple layers and runs through a series of
    Transformer decoder layers to output the final predictions.
    """

    export = False  # export mode

    def __init__(
        self,
        nc=80,
        ch=(512, 1024, 2048),
        hd=256,  # hidden dim
        nq=300,  # num queries
        ndp=4,  # num decoder points
        nh=8,  # num head
        ndl=6,  # num decoder layers
        d_ffn=1024,  # dim of feedforward
        dropout=0.0,
        act=nn.ReLU(),
        eval_idx=-1,
        # Training args
        nd=100,  # num denoising
        label_noise_ratio=0.5,
        box_noise_scale=1.0,
        learnt_init_query=False,
    ):
        """
        Initializes the RTDETRDecoder module with the given parameters.

        Args:
            nc (int): Number of classes. Default is 80.
            ch (tuple): Channels in the backbone feature maps. Default is (512, 1024, 2048).
            hd (int): Dimension of hidden layers. Default is 256.
            nq (int): Number of query points. Default is 300.
            ndp (int): Number of decoder points. Default is 4.
            nh (int): Number of heads in multi-head attention. Default is 8.
            ndl (int): Number of decoder layers. Default is 6.
            d_ffn (int): Dimension of the feed-forward networks. Default is 1024.
            dropout (float): Dropout rate. Default is 0.
            act (nn.Module): Activation function. Default is nn.ReLU.
            eval_idx (int): Evaluation index. Default is -1.
            nd (int): Number of denoising queries. Default is 100.
            label_noise_ratio (float): Label noise ratio. Default is 0.5.
            box_noise_scale (float): Box noise scale. Default is 1.0.
            learnt_init_query (bool): Whether to learn initial query embeddings. Default is False.
        """
        super().__init__()
        self.hidden_dim = hd
        self.nhead = nh
        self.nl = len(ch)  # num level
        self.nc = nc
        self.num_queries = nq
        self.num_decoder_layers = ndl

        # Backbone feature projection
        self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)
        # NOTE: simplified version but it's not consistent with .pt weights.
        # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)

        # Transformer module
        decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
        self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)

        # Denoising part
        self.denoising_class_embed = nn.Embedding(nc, hd)
        self.num_denoising = nd
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale

        # Decoder embedding
        self.learnt_init_query = learnt_init_query
        if learnt_init_query:
            self.tgt_embed = nn.Embedding(nq, hd)
        self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)

        # Encoder head
        self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
        self.enc_score_head = nn.Linear(hd, nc)
        self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)

        # Decoder head
        self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
        self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])

        self._reset_parameters()

    def forward(self, x, batch=None):
        """Runs the forward pass of the module, returning bounding box and classification scores for the input."""
        from ultralytics.models.utils.ops import get_cdn_group

        # Input projection and embedding
        feats, shapes = self._get_encoder_input(x)

        # Prepare denoising training
        dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group(
            batch,
            self.nc,
            self.num_queries,
            self.denoising_class_embed.weight,
            self.num_denoising,
            self.label_noise_ratio,
            self.box_noise_scale,
            self.training,
        )

        embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)

        # Decoder
        dec_bboxes, dec_scores = self.decoder(
            embed,
            refer_bbox,
            feats,
            shapes,
            self.dec_bbox_head,
            self.dec_score_head,
            self.query_pos_head,
            attn_mask=attn_mask,
        )
        x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
        if self.training:
            return x
        # (bs, 300, 4+nc)
        y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
        return y if self.export else (y, x)

    def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2):
        """Generates anchor bounding boxes for given shapes with specific grid size and validates them."""
        anchors = []
        for i, (h, w) in enumerate(shapes):
            sy = torch.arange(end=h, dtype=dtype, device=device)
            sx = torch.arange(end=w, dtype=dtype, device=device)
            grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
            grid_xy = torch.stack([grid_x, grid_y], -1)  # (h, w, 2)

            valid_WH = torch.tensor([w, h], dtype=dtype, device=device)
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH  # (1, h, w, 2)
            wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i)
            anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4))  # (1, h*w, 4)

        anchors = torch.cat(anchors, 1)  # (1, h*w*nl, 4)
        valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True)  # (1, h*w*nl, 1)
        anchors = torch.log(anchors / (1 - anchors))
        anchors = anchors.masked_fill(~valid_mask, float("inf"))
        return anchors, valid_mask

    def _get_encoder_input(self, x):
        """Processes and returns encoder inputs by getting projection features from input and concatenating them."""
        # Get projection features
        x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
        # Get encoder inputs
        feats = []
        shapes = []
        for feat in x:
            h, w = feat.shape[2:]
            # [b, c, h, w] -> [b, h*w, c]
            feats.append(feat.flatten(2).permute(0, 2, 1))
            # [nl, 2]
            shapes.append([h, w])

        # [b, h*w, c]
        feats = torch.cat(feats, 1)
        return feats, shapes

    def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
        """Generates and prepares the input required for the decoder from the provided features and shapes."""
        bs = feats.shape[0]
        # Prepare input for decoder
        anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
        features = self.enc_output(valid_mask * feats)  # bs, h*w, 256

        enc_outputs_scores = self.enc_score_head(features)  # (bs, h*w, nc)

        # Query selection
        # (bs, num_queries)
        topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
        # (bs, num_queries)
        batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)

        # (bs, num_queries, 256)
        top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
        # (bs, num_queries, 4)
        top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)

        # Dynamic anchors + static content
        refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors

        enc_bboxes = refer_bbox.sigmoid()
        if dn_bbox is not None:
            refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
        enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)

        embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features
        if self.training:
            refer_bbox = refer_bbox.detach()
            if not self.learnt_init_query:
                embeddings = embeddings.detach()
        if dn_embed is not None:
            embeddings = torch.cat([dn_embed, embeddings], 1)

        return embeddings, refer_bbox, enc_bboxes, enc_scores

    # TODO
    def _reset_parameters(self):
        """Initializes or resets the parameters of the model's various components with predefined weights and biases."""
        # Class and bbox head init
        bias_cls = bias_init_with_prob(0.01) / 80 * self.nc
        # NOTE: the weight initialization in `linear_init` would cause NaN when training with custom datasets.
        # linear_init(self.enc_score_head)
        constant_(self.enc_score_head.bias, bias_cls)
        constant_(self.enc_bbox_head.layers[-1].weight, 0.0)
        constant_(self.enc_bbox_head.layers[-1].bias, 0.0)
        for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
            # linear_init(cls_)
            constant_(cls_.bias, bias_cls)
            constant_(reg_.layers[-1].weight, 0.0)
            constant_(reg_.layers[-1].bias, 0.0)

        linear_init(self.enc_output[0])
        xavier_uniform_(self.enc_output[0].weight)
        if self.learnt_init_query:
            xavier_uniform_(self.tgt_embed.weight)
        xavier_uniform_(self.query_pos_head.layers[0].weight)
        xavier_uniform_(self.query_pos_head.layers[1].weight)
        for layer in self.input_proj:
            xavier_uniform_(layer[0].weight)
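

# Example (illustrative): `_generate_anchors` returns anchors in logit space, since
# torch.log(a / (1 - a)) is the inverse of sigmoid. This lets `_get_decoder_input`
# add the raw `enc_bbox_head` offsets before the final `refer_bbox.sigmoid()`:
#
#   a = torch.tensor([0.25, 0.50, 0.90])
#   assert torch.allclose(torch.log(a / (1 - a)).sigmoid(), a)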


class v10Detect(Detect):
    """
    v10 Detection head from https://arxiv.org/pdf/2405.14458.

    Args:
        nc (int): Number of classes.
        ch (tuple): Tuple of channel sizes.

    Attributes:
        max_det (int): Maximum number of detections.

    Methods:
        __init__(self, nc=80, ch=()): Initializes the v10Detect object.
        forward(self, x): Performs forward pass of the v10Detect module.
        bias_init(self): Initializes biases of the Detect module.
    """

    end2end = True

    def __init__(self, nc=80, ch=()):
        """Initializes the v10Detect object with the specified number of classes and input channels."""
        super().__init__(nc, ch)
        c3 = max(ch[0], min(self.nc, 100))  # channels
        # Light cls head
        self.cv3 = nn.ModuleList(
            nn.Sequential(
                nn.Sequential(Conv(x, x, 3, g=x), Conv(x, c3, 1)),
                nn.Sequential(Conv(c3, c3, 3, g=c3), Conv(c3, c3, 1)),
                nn.Conv2d(c3, self.nc, 1),
            )
            for x in ch
        )
        self.one2one_cv3 = copy.deepcopy(self.cv3)
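

# Example (illustrative sketch): with end2end=True the inherited forward routes through
# forward_end2end, so training returns both branches and inference is NMS-free. The
# channel and feature-map sizes below are assumptions.
#
#   head = v10Detect(nc=80, ch=(64, 128, 256))
#   feats = [torch.rand(1, c, s, s) for c, s in zip((64, 128, 256), (80, 40, 20))]
#   out = head(feats)                               # {'one2many': [...], 'one2one': [...]}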


class MTDETRDecoder(nn.Module):
    """
    Multi-Task Deformable Transformer Decoder (MTDETRDecoder) module for object detection and semantic segmentation.

    This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes,
    class labels for objects, and semantic segmentation masks following MaskFormer's approach. It integrates features
    from multiple layers and runs through a series of Transformer decoder layers to output the final predictions.
    """

    export = False  # export mode

    def __init__(
        self,
        nc=80,
        ch=(512, 1024, 2048),
        ns_classes=None,  # number of semantic segmentation classes
        hd=256,  # hidden dim
        nq=300,  # num queries
        ndp=4,  # num decoder points
        nh=8,  # num head
        ndl=6,  # num decoder layers
        d_ffn=1024,  # dim of feedforward
        dropout=0.0,
        act=nn.ReLU(),
        eval_idx=-1,
        # Training args
        nd=100,  # num denoising
        label_noise_ratio=0.5,
        box_noise_scale=1.0,
        learnt_init_query=False,
    ):
        """
        Initializes the MTDETRDecoder module with the given parameters.

        Args:
            nc (int): Number of classes for object detection. Default is 80.
            ch (tuple): Channels in the backbone feature maps. Default is (512, 1024, 2048).
            ns_classes (int, optional): Number of semantic segmentation classes. Default is None.
            hd (int): Dimension of hidden layers. Default is 256.
            nq (int): Number of query points. Default is 300.
            ndp (int): Number of decoder points. Default is 4.
            nh (int): Number of heads in multi-head attention. Default is 8.
            ndl (int): Number of decoder layers. Default is 6.
            d_ffn (int): Dimension of the feed-forward networks. Default is 1024.
            dropout (float): Dropout rate. Default is 0.
            act (nn.Module): Activation function. Default is nn.ReLU.
            eval_idx (int): Evaluation index. Default is -1.
            nd (int): Number of denoising queries. Default is 100.
            label_noise_ratio (float): Label noise ratio. Default is 0.5.
            box_noise_scale (float): Box noise scale. Default is 1.0.
            learnt_init_query (bool): Whether to learn initial query embeddings. Default is False.
        """
        super().__init__()
        self.hidden_dim = hd
        self.nhead = nh
        self.nl = len(ch)  # num level
        self.nc = nc
        self.num_queries = nq
        self.num_decoder_layers = ndl

        # Backbone feature projection
        self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)
        # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)

        # Transformer module
        decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
        self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)
        # self.decoder = DeformableTransformerDecoder_withseg(hd, decoder_layer, ndl, eval_idx)

        # Denoising part
        self.denoising_class_embed = nn.Embedding(nc, hd)
        self.num_denoising = nd
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale

        # Decoder embedding
        self.learnt_init_query = learnt_init_query
        if learnt_init_query:
            self.tgt_embed = nn.Embedding(nq, hd)
        self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)

        # Encoder head
        self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
        self.enc_score_head = nn.Linear(hd, nc)
        self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)

        # Decoder head
        self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
        self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])

        # Semantic segmentation head
        self.imgsz = 640
        self.seg_head = TransformerSegmentationDecoder(hd, ns_classes)

        # Per-task feature adapters
        self.task_adapter_det = nn.ModuleList([TaskAdapterLite(hd) for _ in ch])
        self.task_adapter_seg = nn.ModuleList([TaskAdapterLite(hd) for _ in ch])

        # Dynamic gates blending shared and task-specific features
        self.gate_det = nn.ModuleList([LiteDynamicGate(hd, reduction=8) for _ in ch])
        self.gate_seg = nn.ModuleList([LiteDynamicGate(hd, reduction=8) for _ in ch])

        self._reset_parameters()

    def forward(self, x, batch=None):
        """Runs the forward pass of the module, returning bounding boxes, classification scores, and segmentation masks."""
        from ultralytics.models.utils.ops import get_cdn_group

        # Input projection and embedding. x_proj is the CCF output (per-level features with the
        # same channel count); feats_flat is the flattened, concatenated fusion feature.
        x_proj, feats_flat, shapes, gate_mean = self._get_encoder_input(x)

        # Prepare denoising training
        dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group(
            batch,
            self.nc,
            self.num_queries,
            self.denoising_class_embed.weight,
            self.num_denoising,
            self.label_noise_ratio,
            self.box_noise_scale,
            self.training,
        )

        embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats_flat, shapes, dn_embed, dn_bbox)

        # Segmentation decoder
        seg_mask, aux_list = self.seg_head(x_proj, self.imgsz)

        # Decoder
        dec_bboxes, dec_scores = self.decoder(
            embed,
            refer_bbox,
            feats_flat,
            shapes,
            self.dec_bbox_head,
            self.dec_score_head,
            self.query_pos_head,
            attn_mask=attn_mask,
        )

        x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta, seg_mask, gate_mean
        if self.training:
            return x
        # (bs, 300, 4+nc)
        y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
        # TODO: when self.export is True, y must also include the segmentation masks.
        return y if self.export else (y, x)

    def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2):
        """Generates anchor bounding boxes for given shapes with specific grid size and validates them."""
        anchors = []
        for i, (h, w) in enumerate(shapes):
            sy = torch.arange(end=h, dtype=dtype, device=device)
            sx = torch.arange(end=w, dtype=dtype, device=device)
            grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
            grid_xy = torch.stack([grid_x, grid_y], -1)  # (h, w, 2)

            valid_WH = torch.tensor([w, h], dtype=dtype, device=device)
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH  # (1, h, w, 2)
            wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i)
            anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4))  # (1, h*w, 4)

        anchors = torch.cat(anchors, 1)  # (1, h*w*nl, 4)
        valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True)  # (1, h*w*nl, 1)
        anchors = torch.log(anchors / (1 - anchors))
        anchors = anchors.masked_fill(~valid_mask, float("inf"))
        return anchors, valid_mask

    # Previous (ungated) version, kept for reference:
    # def _get_encoder_input(self, x):
    #     """Processes and returns encoder inputs by getting projection features from input and concatenating them."""
    #     # Get projection features
    #     x_proj = [self.input_proj[i](feat) for i, feat in enumerate(x)]
    #     # Get encoder inputs
    #     feats = []
    #     shapes = []
    #     for feat in x_proj:
    #         h, w = feat.shape[2:]
    #         # [b, c, h, w] -> [b, h*w, c]
    #         feats.append(feat.flatten(2).permute(0, 2, 1))
    #         # [nl, 2]
    #         shapes.append([h, w])
    #
    #     # [b, h*w, c]
    #     feats_flat = torch.cat(feats, 1)
    #     return x_proj, feats_flat, shapes, [1, 1]

    def _get_encoder_input(self, x):
        """Projects each level, applies the per-task adapters and gates, and returns encoder inputs."""
        x_proj_det = []
        x_proj_seg = []
        gate_mean_det_list = []
        gate_mean_seg_list = []

        shapes = []

        for i, feat in enumerate(x):
            proj_feat = self.input_proj[i](feat)

            # Object detection gate
            det_feat = self.task_adapter_det[i](proj_feat)
            gated_det, gate_mean_det = self.gate_det[i](proj_feat, det_feat)
            # self.merge_topk_heatmap(gate_mean_det[0], K=10, mode='mean')
            # self.plot_gate_heatmap(gate_mean_det, sample_idx=0)
            # self.plot_diff_distribution(proj_feat, det_feat, sample_idx=0)
            x_proj_det.append(gated_det)
            gate_mean_det_list.append(gate_mean_det.mean().item())

            # Segmentation gate
            seg_feat = self.task_adapter_seg[i](proj_feat)
            gated_seg, gate_mean_seg = self.gate_seg[i](proj_feat, seg_feat)
            # self.merge_topk_heatmap(gate_mean_seg[0], K=10, mode='mean')
            # self.plot_gate_heatmap(gate_mean_seg, sample_idx=0)
            # self.plot_diff_distribution(proj_feat, seg_feat, sample_idx=0)
            x_proj_seg.append(gated_seg)
            gate_mean_seg_list.append(gate_mean_seg.mean().item())

            # D_det = (det_feat - proj_feat).flatten(2).norm(dim=2).mean().item()
            # D_seg = (seg_feat - proj_feat).flatten(2).norm(dim=2).mean().item()
            # D_shared = proj_feat.flatten(2).norm(dim=2).mean().item()
            # print(f"Residual ‖det–shared‖/D_shared={D_det/D_shared:.4f}, ‖seg–shared‖/D_shared={D_seg/D_shared:.4f}")

            h, w = proj_feat.shape[2:]
            shapes.append([h, w])

        # Detection features are flattened for the deformable decoder; segmentation keeps spatial maps
        feats_flat_det = torch.cat([f.flatten(2).permute(0, 2, 1) for f in x_proj_det], 1)

        return x_proj_seg, feats_flat_det, shapes, [np.mean(gate_mean_det_list), np.mean(gate_mean_seg_list)]

    def merge_topk_heatmap(self, gate_map: torch.Tensor, K: int = 5, mode: str = "mean", normalize: bool = True) -> None:
        """Debug helper: plots a heatmap merged from the K highest-variance channels of a gate map."""
        if gate_map.dim() == 4:
            G = gate_map[0]  # take the first sample: [C, H, W]
        else:
            G = gate_map  # [C, H, W]

        C, H, W = G.shape
        var_per_ch = G.view(C, -1).var(dim=1)  # [C]
        topk_idx = torch.topk(var_per_ch, k=K).indices  # [K]

        selected = G[topk_idx]  # [K, H, W]
        if mode == "sum":
            merged = selected.sum(dim=0)  # [H, W]
        else:  # 'mean'
            merged = selected.mean(dim=0)  # [H, W]

        if normalize:
            mn, mx = merged.min(), merged.max()
            if mx > mn:
                merged = (merged - mn) / (mx - mn)
        plt.figure(figsize=(4, 4))
        plt.imshow(merged.detach().cpu().numpy(), cmap="viridis")
        plt.axis("off")
        plt.show()

    def plot_gate_heatmap(self, gate: torch.Tensor, sample_idx=0, channel_idx=None):
        """Debug helper: plots one channel of a gate map [B, C, H, W], defaulting to the highest-variance channel."""
        G = gate[sample_idx].detach().cpu().numpy()  # [C, H, W]
        C, H, W = G.shape
        if channel_idx is None:
            var = G.reshape(C, -1).var(axis=1)
            channel_idx = int(var.argmax())
        heatmap = G[channel_idx]
        plt.figure(figsize=(4, 4))
        plt.imshow(heatmap, cmap="viridis")
        plt.title(f"Gate Heatmap (ch={channel_idx})")
        plt.colorbar()
        plt.axis("off")
        plt.show()

    def plot_diff_distribution(self, shared: torch.Tensor, task: torch.Tensor, sample_idx=0):
        """Debug helper: plots the histogram of per-location L2 norms of (task - shared) features."""
        D = (task - shared)[sample_idx].detach().cpu().numpy()  # [C, H, W]
        flat = D.reshape(D.shape[0], -1)
        l2 = np.linalg.norm(flat, axis=0)  # L2 norm over channels at each spatial location
        plt.figure(figsize=(5, 3))
        plt.hist(l2, bins=100, color="gray", edgecolor="black")
        plt.title("Distribution of ||task - shared||₂")
        plt.xlabel("L2 norm")
        plt.ylabel("Freq.")
        plt.show()

    def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
        """Generates and prepares the input required for the decoder from the provided features and shapes."""
        bs = feats.shape[0]
        # Prepare input for decoder
        anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
        features = self.enc_output(valid_mask * feats)  # bs, h*w, 256

        enc_outputs_scores = self.enc_score_head(features)  # (bs, h*w, nc)

        # Query selection
        # (bs, num_queries)
        topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
        # (bs, num_queries)
        batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)

        # (bs, num_queries, 256)
        top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
        # (bs, num_queries, 4)
        top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)

        # Dynamic anchors + static content
        refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors

        enc_bboxes = refer_bbox.sigmoid()
        if dn_bbox is not None:
            refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
        enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)

        embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features
        if self.training:
            refer_bbox = refer_bbox.detach()
            if not self.learnt_init_query:
                embeddings = embeddings.detach()
        if dn_embed is not None:
            embeddings = torch.cat([dn_embed, embeddings], 1)

        return embeddings, refer_bbox, enc_bboxes, enc_scores

    # TODO
    def _reset_parameters(self):
        """Initializes or resets the parameters of the model's various components with predefined weights and biases."""
        # Class and bbox head init
        bias_cls = bias_init_with_prob(0.01) / 80 * self.nc
        # NOTE: the weight initialization in `linear_init` would cause NaN when training with custom datasets.
        # linear_init(self.enc_score_head)
        constant_(self.enc_score_head.bias, bias_cls)
        constant_(self.enc_bbox_head.layers[-1].weight, 0.0)
        constant_(self.enc_bbox_head.layers[-1].bias, 0.0)
        for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
            # linear_init(cls_)
            constant_(cls_.bias, bias_cls)
            constant_(reg_.layers[-1].weight, 0.0)
            constant_(reg_.layers[-1].bias, 0.0)

        linear_init(self.enc_output[0])
        xavier_uniform_(self.enc_output[0].weight)
        if self.learnt_init_query:
            xavier_uniform_(self.tgt_embed.weight)
        xavier_uniform_(self.query_pos_head.layers[0].weight)
        xavier_uniform_(self.query_pos_head.layers[1].weight)
        for layer in self.input_proj:
            xavier_uniform_(layer[0].weight)


class TaskAdapterLite(nn.Module):
    """Lightweight task adapter: a 1x1 projection followed by a depthwise-separable 3x3 block."""

    def __init__(self, dim):
        """Initializes the adapter with pointwise and depthwise-separable convolution blocks."""
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1, bias=False),
            nn.BatchNorm2d(dim),
            nn.SiLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim, bias=False),  # depthwise
            nn.Conv2d(dim, dim, kernel_size=1, bias=False),  # pointwise
            nn.BatchNorm2d(dim),
            nn.SiLU(),
        )
        self._init_weights()

    def _init_weights(self):
        """Applies Kaiming-normal initialization to all convolution weights."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Returns the task-specific feature map produced from the shared input."""
        y = self.conv1(x)
        y = self.conv3(y)
        return y


class LiteDynamicGate(nn.Module):
    """Lightweight dynamic gate that blends shared and task-specific features per channel and location."""

    def __init__(self, in_dim, reduction=16, clamp_min=0.05, clamp_max=0.95):
        """Initializes channel attention, spatial attention, and the alpha net that mixes them."""
        super().__init__()
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        cat_dim = in_dim * 2

        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(cat_dim, in_dim // reduction, 1, bias=True),
            nn.BatchNorm2d(in_dim // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_dim // reduction, in_dim, 1, bias=True),
            nn.Sigmoid(),
        )

        self.spatial_att = nn.Sequential(
            nn.Conv2d(cat_dim, cat_dim, kernel_size=3, padding=1, groups=cat_dim),
            nn.BatchNorm2d(cat_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(cat_dim, in_dim, kernel_size=1, bias=True),
            nn.Sigmoid(),
        )

        self.alpha_net = nn.Sequential(
            nn.Conv2d(cat_dim, in_dim, 3, padding=1, groups=in_dim),
            nn.BatchNorm2d(in_dim),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),  # [B, C, 1, 1]
            nn.Conv2d(in_dim, in_dim, 1, bias=True),
            nn.BatchNorm2d(in_dim),
            nn.Sigmoid(),
        )
        self._init_weights()

    def _init_weights(self):
        """Applies Kaiming-normal weight and small-noise bias initialization to all convolutions."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.normal_(m.bias, mean=0.0, std=1e-2)

    def forward(self, shared: torch.Tensor, task: torch.Tensor):
        """
        Blend shared and task-specific features with a clamped dynamic gate.

        Args:
            shared (torch.Tensor): Shared features, shape [B, C, H, W].
            task (torch.Tensor): Task-specific features, shape [B, C, H, W].

        Returns:
            out (torch.Tensor): Gated features, shape [B, C, H, W].
            gate (torch.Tensor): Gate values in [clamp_min, clamp_max], shape [B, C, H, W].
        """
        x = torch.cat([shared, task], dim=1)  # [B, 2C, H, W]

        c_gate = self.channel_att(x)  # [B, C, 1, 1]
        s_gate = self.spatial_att(x)  # [B, C, H, W]
        alpha_map = self.alpha_net(x)  # [B, C, 1, 1]

        gate = alpha_map * c_gate + (1 - alpha_map) * s_gate
        gate = gate.clamp(self.clamp_min, self.clamp_max)

        out = shared + gate * (task - shared)

        return out, gate
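

if __name__ == "__main__":
    # Minimal smoke test for the adapter + gate pair (illustrative only; the channel
    # count, spatial size, and batch size below are arbitrary assumptions).
    shared = torch.rand(2, 256, 20, 20)  # shared backbone features
    adapter = TaskAdapterLite(256)  # task-specific residual branch
    gate = LiteDynamicGate(256, reduction=8)  # blends shared and task features
    task = adapter(shared)
    out, g = gate(shared, task)
    print(out.shape, g.shape)  # both torch.Size([2, 256, 20, 20])
    print(float(g.min()), float(g.max()))  # gate values lie within [0.05, 0.95]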