[Major] Add FLOPs counter.
This commit is contained in:
parent
c91bbb8841
commit
86b077bea0
|
|
@ -0,0 +1,126 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from mmdet.models.backbones.swin import WindowMSA, ShiftWindowMSA
|
||||||
|
from mmdet3d.ops.spconv import SparseConv3d, SubMConv3d
|
||||||
|
from mmdet3d.models.utils.transformer import MultiheadAttention
|
||||||
|
from typing import Union
|
||||||
|
from thop import profile
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["flops_counter"]
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: no need to consider ShiftWindowMSA since it contains WindowMSA
|
||||||
|
def count_window_msa(m: Union[WindowMSA, ShiftWindowMSA], x, y):
|
||||||
|
if isinstance(m, WindowMSA):
|
||||||
|
embed_dims = m.embed_dims
|
||||||
|
num_heads = m.num_heads
|
||||||
|
else:
|
||||||
|
embed_dims = m.w_msa.embed_dims
|
||||||
|
num_heads = m.w_msa.num_heads
|
||||||
|
B, N, C = x[0].shape
|
||||||
|
# qkv = model.qkv(x)
|
||||||
|
m.total_ops += B * N * embed_dims * 3 * embed_dims
|
||||||
|
# attn = (q @ k.transpose(-2, -1))
|
||||||
|
m.total_ops += B * num_heads * N * (embed_dims // num_heads) * N
|
||||||
|
# x = (attn @ v)
|
||||||
|
m.total_ops += num_heads * B * N * N * (embed_dims // num_heads)
|
||||||
|
# x = m.proj(x)
|
||||||
|
m.total_ops += B * N * embed_dims * embed_dims
|
||||||
|
|
||||||
|
|
||||||
|
def count_sparseconv(m: Union[SparseConv3d, SubMConv3d], x, y):
|
||||||
|
indice_dict = y.indice_dict[m.indice_key]
|
||||||
|
kmap_size = indice_dict[-2].sum().item()
|
||||||
|
m.total_ops += kmap_size * x[0].features.shape[1] * y.features.shape[1]
|
||||||
|
|
||||||
|
|
||||||
|
def count_mha(m: Union[MultiheadAttention, nn.MultiheadAttention], x, y):
|
||||||
|
flops = 0
|
||||||
|
if len(x) == 3:
|
||||||
|
q, k, v = x
|
||||||
|
elif len(x) == 2:
|
||||||
|
q, k = x
|
||||||
|
v = k
|
||||||
|
elif len(x) == 1:
|
||||||
|
q = x[0]
|
||||||
|
k = v = q
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
batch_first = m.batch_first \
|
||||||
|
if hasattr(m, 'batch_first') else False
|
||||||
|
if batch_first:
|
||||||
|
batch_size = q.shape[0]
|
||||||
|
len_idx = 1
|
||||||
|
else:
|
||||||
|
batch_size = q.shape[1]
|
||||||
|
len_idx = 0
|
||||||
|
|
||||||
|
dim_idx = 2
|
||||||
|
|
||||||
|
qdim = q.shape[dim_idx]
|
||||||
|
kdim = k.shape[dim_idx]
|
||||||
|
vdim = v.shape[dim_idx]
|
||||||
|
|
||||||
|
qlen = q.shape[len_idx]
|
||||||
|
klen = k.shape[len_idx]
|
||||||
|
vlen = v.shape[len_idx]
|
||||||
|
|
||||||
|
num_heads = m.num_heads
|
||||||
|
assert qdim == m.embed_dim
|
||||||
|
|
||||||
|
if m.kdim is None:
|
||||||
|
assert kdim == qdim
|
||||||
|
if m.vdim is None:
|
||||||
|
assert vdim == qdim
|
||||||
|
|
||||||
|
flops = 0
|
||||||
|
|
||||||
|
# Q scaling
|
||||||
|
flops += qlen * qdim
|
||||||
|
|
||||||
|
# Initial projections
|
||||||
|
flops += (
|
||||||
|
(qlen * qdim * qdim) # QW
|
||||||
|
+ (klen * kdim * kdim) # KW
|
||||||
|
+ (vlen * vdim * vdim) # VW
|
||||||
|
)
|
||||||
|
|
||||||
|
if m.in_proj_bias is not None:
|
||||||
|
flops += (qlen + klen + vlen) * qdim
|
||||||
|
|
||||||
|
# attention heads: scale, matmul, softmax, matmul
|
||||||
|
qk_head_dim = qdim // num_heads
|
||||||
|
v_head_dim = vdim // num_heads
|
||||||
|
|
||||||
|
head_flops = (
|
||||||
|
(qlen * klen * qk_head_dim) # QK^T
|
||||||
|
+ (qlen * klen) # softmax
|
||||||
|
+ (qlen * klen * v_head_dim) # AV
|
||||||
|
)
|
||||||
|
|
||||||
|
flops += num_heads * head_flops
|
||||||
|
|
||||||
|
# final projection, bias is always enabled
|
||||||
|
flops += qlen * vdim * (vdim + 1)
|
||||||
|
|
||||||
|
flops *= batch_size
|
||||||
|
m.total_ops += flops
|
||||||
|
|
||||||
|
|
||||||
|
def flops_counter(model, inputs):
|
||||||
|
macs, params = profile(
|
||||||
|
model,
|
||||||
|
inputs,
|
||||||
|
custom_ops={
|
||||||
|
WindowMSA: count_window_msa,
|
||||||
|
#ShiftWindowMSA: count_window_msa,
|
||||||
|
SparseConv3d: count_sparseconv,
|
||||||
|
SubMConv3d: count_sparseconv,
|
||||||
|
MultiheadAttention: count_mha
|
||||||
|
},
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return macs, params
|
||||||
Loading…
Reference in New Issue