"""
Creates a MobileNetV4 model as defined in:

Danfeng Qin, Chas Leichner, Manolis Delakis, Marco Fornoni, Shixin Luo, Fan Yang, Weijun Wang, Colby Banbury, Chengxi Ye, Berkin Akin, Vaibhav Aggarwal, Tenghui Zhu, Daniele Moro, Andrew Howard. (2024).
MobileNetV4 - Universal Models for the Mobile Ecosystem.
arXiv preprint arXiv:2404.10518.
"""
import math

import torch
import torch.nn as nn


# Public API for ``from <module> import *``. Every name listed here must be
# defined in this module, otherwise the star-import raises AttributeError.
__all__ = ['mobilenetv4_conv_tiny', 'mobilenetv4_conv_small',
           'mobilenetv4_conv_medium', 'mobilenetv4_conv_large']


def make_divisible(value, divisor, min_value=None, round_down_protect=True):
    """Round ``value`` to a multiple of ``divisor``.

    Rounding is half-up for positive values, and the result is never below
    ``min_value``. When ``min_value`` is None, ``divisor`` is used as the
    lower bound.

    Args:
        value: the (possibly fractional) channel count to round.
        divisor: the multiple the result must be divisible by.
        min_value: lower bound on the result; ``None`` means ``divisor``.
        round_down_protect: if True and rounding landed more than 10% below
            ``value``, bump the result up by one ``divisor``.

    Returns:
        The rounded value.
    """
    lower_bound = divisor if min_value is None else min_value
    nearest_multiple = int(value + divisor / 2) // divisor * divisor
    result = max(lower_bound, nearest_multiple)
    # Rounding down must not shrink the value by more than 10%.
    too_small = round_down_protect and result < 0.9 * value
    if too_small:
        result += divisor
    return result


class ConvBN(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU.

    Padding is ``(kernel_size - 1) // 2``, so odd kernels at stride 1 keep
    the spatial size.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: square kernel size.
        stride: convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()
        padding = (kernel_size - 1) // 2
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU(inplace=True)
        # Keep a single Sequential so parameter names stay ``block.0.*`` / ``block.1.*``.
        self.block = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.block(x)


class UniversalInvertedBottleneck(nn.Module):
    """Universal Inverted Bottleneck (UIB) block.

    Layout:
      1. optional start depthwise conv + BN (no activation)
      2. 1x1 expand conv + BN + ReLU
      3. optional middle depthwise conv + BN + ReLU
      4. 1x1 projection conv + BN (no activation)
      5. optional per-channel layer scale
      6. residual add when ``stride == 1`` and ``in_channels == out_channels``

    A depthwise kernel size of 0 (or any falsy value) omits that conv.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        expand_ratio: hidden width multiplier. The hidden width is
            ``make_divisible(in_channels * expand_ratio, 8)``.
        start_dw_kernel_size: kernel of the depthwise conv before expansion (0 to skip).
        middle_dw_kernel_size: kernel of the depthwise conv after expansion (0 to skip).
        stride: spatial stride of the block.
        middle_dw_downsample: put the stride on the middle depthwise conv
            (default) instead of the start one. If the block has no middle
            depthwise conv, the stride is applied by the start one.
        use_layer_scale: scale the projected output by a learnable
            per-channel ``gamma``.
        layer_scale_init_value: initial value of every ``gamma`` entry.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expand_ratio,
                 start_dw_kernel_size,
                 middle_dw_kernel_size,
                 stride,
                 middle_dw_downsample: bool = True,
                 use_layer_scale: bool = False,
                 layer_scale_init_value: float = 1e-5):
        super(UniversalInvertedBottleneck, self).__init__()
        self.start_dw_kernel_size = start_dw_kernel_size
        self.middle_dw_kernel_size = middle_dw_kernel_size

        # The stride goes to the middle depthwise conv only if it was
        # requested there *and* that conv exists. Otherwise the start
        # depthwise conv takes it, so a block without a middle depthwise conv
        # still downsamples.
        # NOTE: if both kernel sizes are 0, no layer applies the stride.
        stride_in_middle = bool(middle_dw_downsample and middle_dw_kernel_size)

        if start_dw_kernel_size:
            self.start_dw_conv = nn.Conv2d(in_channels, in_channels, start_dw_kernel_size,
                                           1 if stride_in_middle else stride,
                                           (start_dw_kernel_size - 1) // 2,
                                           groups=in_channels, bias=False)
            self.start_dw_norm = nn.BatchNorm2d(in_channels)

        expand_channels = make_divisible(in_channels * expand_ratio, 8)
        self.expand_conv = nn.Conv2d(in_channels, expand_channels, 1, 1, bias=False)
        self.expand_norm = nn.BatchNorm2d(expand_channels)
        self.expand_act = nn.ReLU(inplace=True)

        if middle_dw_kernel_size:
            self.middle_dw_conv = nn.Conv2d(expand_channels, expand_channels, middle_dw_kernel_size,
                                            stride if stride_in_middle else 1,
                                            (middle_dw_kernel_size - 1) // 2,
                                            groups=expand_channels, bias=False)
            self.middle_dw_norm = nn.BatchNorm2d(expand_channels)
            self.middle_dw_act = nn.ReLU(inplace=True)

        self.proj_conv = nn.Conv2d(expand_channels, out_channels, 1, 1, bias=False)
        self.proj_norm = nn.BatchNorm2d(out_channels)

        if use_layer_scale:
            # Kept 1-D with shape (out_channels,); forward reshapes it for NCHW broadcasting.
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(out_channels),
                                      requires_grad=True)

        self.use_layer_scale = use_layer_scale
        # The residual is only added when the output shape can match the input.
        self.identity = stride == 1 and in_channels == out_channels

    def forward(self, x):
        shortcut = x

        if self.start_dw_kernel_size:
            x = self.start_dw_conv(x)
            x = self.start_dw_norm(x)

        x = self.expand_conv(x)
        x = self.expand_norm(x)
        x = self.expand_act(x)

        if self.middle_dw_kernel_size:
            x = self.middle_dw_conv(x)
            x = self.middle_dw_norm(x)
            x = self.middle_dw_act(x)

        x = self.proj_conv(x)
        x = self.proj_norm(x)

        if self.use_layer_scale:
            # gamma is (C,). View it as (1, C, 1, 1) so it scales the channel
            # axis of the (N, C, H, W) activation rather than broadcasting
            # against the trailing width axis.
            x = self.gamma.view(1, -1, 1, 1) * x

        return x + shortcut if self.identity else x


class MobileNetV4(nn.Module):
    """MobileNetV4 feature extractor followed by a pooled classification head.

    Args:
        block_specs: iterable of layer descriptions, built in order.
            Two forms are understood:
              ``('conv_bn', kernel_size, stride, out_channels)``
              ``('uib', start_dw_kernel_size, middle_dw_kernel_size, stride,
              out_channels, expand_ratio)``
        num_classes: output features of the final linear classifier.

    Raises:
        NotImplementedError: when a spec's first element is neither
            ``'conv_bn'`` nor ``'uib'``.

    The head is: global average pool -> 1x1 ConvBN to 1280 channels ->
    flatten -> Linear.
    """

    def __init__(self, block_specs, num_classes=1):
        super().__init__()

        channels = 3  # the first layer expects 3 input channels
        stages = []
        for kind, *params in block_specs:
            if kind == 'conv_bn':
                kernel, step, width = params
                stages.append(ConvBN(channels, width, kernel, step))
            elif kind == 'uib':
                start_k, mid_k, step, width, ratio = params
                stages.append(UniversalInvertedBottleneck(
                    channels, width, ratio, start_k, mid_k, step))
            else:
                raise NotImplementedError
            # The next layer consumes this layer's output width.
            channels = width
        self.features = nn.Sequential(*stages)

        # Registration order (features, avgpool, conv, classifier) is also the
        # order in which _initialize_weights visits modules.
        head_width = 1280
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = ConvBN(channels, head_width, 1)
        self.classifier = nn.Linear(head_width, num_classes)

        self._initialize_weights()

    def forward(self, x):
        x = self.avgpool(self.features(x))
        x = self.conv(x)
        return self.classifier(x.view(x.size(0), -1))

    def _initialize_weights(self):
        """Re-initialize Conv2d, BatchNorm2d and Linear parameters in place.

        - Conv2d: normal(0, sqrt(2 / (kh * kw * out_channels))), bias zeroed.
        - BatchNorm2d: weight 1, bias 0.
        - Linear: normal(0, 0.01), bias zeroed.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kh, kw = module.kernel_size
                fan_out = kh * kw * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.01)
                module.bias.data.zero_()


def mobilenetv4_conv_tiny(**kwargs):
    """
    Constructs a MobileNetV4-Conv-Tiny model.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``MobileNetV4``.

    Spec formats:
        ('conv_bn', kernel_size, stride, out_channels)
        ('uib', start_dw_kernel_size, middle_dw_kernel_size, stride, out_channels, expand_ratio)
    """
    # 112px
    stem = [('conv_bn', 3, 2, 24)]
    # 56px
    stage2 = [('conv_bn', 3, 2, 32), ('conv_bn', 1, 1, 32)]
    # 28px
    stage3 = [('conv_bn', 3, 2, 48), ('conv_bn', 1, 1, 32)]
    # stage 4
    stage4 = [
        ('uib', 5, 5, 2, 48, 3.0),  # ExtraDW
        ('uib', 0, 3, 1, 48, 2.0),  # IB
        ('uib', 3, 0, 1, 48, 3.0),  # ConvNext
    ]
    # stage 5
    stage5 = [
        ('uib', 3, 3, 2, 64, 3.0),  # ExtraDW
        ('uib', 0, 3, 1, 64, 2.0),  # IB
        ('conv_bn', 1, 1, 64),      # Conv
    ]
    return MobileNetV4(stem + stage2 + stage3 + stage4 + stage5, **kwargs)


def mobilenetv4_conv_small(**kwargs):
    """
    Constructs a MobileNetV4-Conv-Small model.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``MobileNetV4``.

    Spec formats:
        ('conv_bn', kernel_size, stride, out_channels)
        ('uib', start_dw_kernel_size, middle_dw_kernel_size, stride, out_channels, expand_ratio)
    """
    # 112px
    stem = [('conv_bn', 3, 2, 32)]
    # 56px
    stage_56 = [('conv_bn', 3, 2, 32), ('conv_bn', 1, 1, 32)]
    # 28px
    stage_28 = [('conv_bn', 3, 2, 96), ('conv_bn', 1, 1, 64)]
    # 14px: ExtraDW, four identical IB blocks, then ConvNext
    stage_14 = (
        [('uib', 5, 5, 2, 96, 3.0)]      # ExtraDW
        + [('uib', 0, 3, 1, 96, 2.0)] * 4  # IB x4
        + [('uib', 3, 0, 1, 96, 4.0)]    # ConvNext
    )
    # 7px
    stage_7 = [
        ('uib', 3, 3, 2, 128, 6.0),  # ExtraDW
        ('uib', 5, 5, 1, 128, 4.0),  # ExtraDW
        ('uib', 0, 5, 1, 128, 4.0),  # IB
        ('uib', 0, 5, 1, 128, 3.0),  # IB
        ('uib', 0, 3, 1, 128, 4.0),  # IB
        ('uib', 0, 3, 1, 128, 4.0),  # IB
        ('conv_bn', 1, 1, 960),      # Conv
    ]
    return MobileNetV4(stem + stage_56 + stage_28 + stage_14 + stage_7, **kwargs)


def mobilenetv4_conv_medium(**kwargs):
    """
    Constructs a MobileNetV4-Conv-Medium model.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``MobileNetV4``.
    """

    def uib_stage(width, *layers):
        # Each layer is (start_dw_kernel_size, middle_dw_kernel_size, stride, expand_ratio);
        # every block in the stage outputs ``width`` channels.
        return [('uib', start_k, mid_k, step, width, ratio)
                for start_k, mid_k, step, ratio in layers]

    stem = [('conv_bn', 3, 2, 32)]
    # 2nd stage
    stage2 = [('conv_bn', 3, 2, 128), ('conv_bn', 1, 1, 48)]
    # 3rd stage
    stage3 = uib_stage(80,
                       (3, 5, 2, 4.0),
                       (3, 3, 1, 2.0))
    # 4th stage
    stage4 = uib_stage(160,
                       (3, 5, 2, 6.0),
                       (3, 3, 1, 4.0),
                       (3, 3, 1, 4.0),
                       (3, 5, 1, 4.0),
                       (3, 3, 1, 4.0),
                       (3, 0, 1, 4.0),
                       (0, 0, 1, 2.0),
                       (3, 0, 1, 4.0))
    # 5th stage
    stage5 = uib_stage(256,
                       (5, 5, 2, 6.0),
                       (5, 5, 1, 4.0),
                       (3, 5, 1, 4.0),
                       (3, 5, 1, 4.0),
                       (0, 0, 1, 4.0),
                       (3, 0, 1, 4.0),
                       (3, 5, 1, 2.0),
                       (5, 5, 1, 4.0),
                       (0, 0, 1, 4.0),
                       (0, 0, 1, 4.0),
                       (5, 0, 1, 2.0))
    # FC layers
    head = [('conv_bn', 1, 1, 960)]

    return MobileNetV4(stem + stage2 + stage3 + stage4 + stage5 + head, **kwargs)


def mobilenetv4_conv_large(**kwargs):
    """
    Constructs a MobileNetV4-Conv-Large model.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``MobileNetV4``.
    Every UIB block in this configuration uses an expand ratio of 4.0.
    """

    def uib_stage(width, *layers):
        # Each layer is (start_dw_kernel_size, middle_dw_kernel_size, stride).
        return [('uib', start_k, mid_k, step, width, 4.0)
                for start_k, mid_k, step in layers]

    stem = [
        ('conv_bn', 3, 2, 24),
        ('conv_bn', 3, 2, 96),
        ('conv_bn', 1, 1, 48),
    ]
    stage_96 = uib_stage(96,
                         (3, 5, 2),
                         (3, 3, 1))
    stage_192 = uib_stage(192,
                          (3, 5, 2),
                          (3, 3, 1),
                          (3, 3, 1),
                          (3, 3, 1),
                          (3, 5, 1),
                          (5, 3, 1),
                          (5, 3, 1),
                          (5, 3, 1),
                          (5, 3, 1),
                          (5, 3, 1),
                          (3, 0, 1))
    stage_512 = uib_stage(512,
                          (5, 5, 2),
                          (5, 5, 1),
                          (5, 5, 1),
                          (5, 5, 1),
                          (5, 0, 1),
                          (5, 3, 1),
                          (5, 0, 1),
                          (5, 0, 1),
                          (5, 3, 1),
                          (5, 5, 1),
                          (5, 0, 1),
                          (5, 0, 1),
                          (5, 0, 1))
    head = [('conv_bn', 1, 1, 960)]

    return MobileNetV4(stem + stage_96 + stage_192 + stage_512 + head, **kwargs)


# When run as a script: report model complexity (via thop) and CPU latency.
if __name__ == '__main__':
    from time import perf_counter

    from thop import profile

    model = mobilenetv4_conv_tiny(num_classes=1)
    model.eval()
    dummy_input = torch.randn(1, 3, 640, 360)

    # thop counts multiply-accumulate operations (MACs); one MAC is ~2 FLOPs,
    # so report both instead of labelling the MAC count as FLOPs.
    macs, params = profile(model, inputs=(dummy_input,))
    print(f'MACs: {macs / 1e9:.2f} G')
    print(f'FLOPs: {2 * macs / 1e9:.2f} G')
    print(f'Params: {params / 1e6:.2f} M')

    # Speed test: 50 warm-up iterations, then time 200. Autograd is disabled so
    # the timing measures inference, not graph construction.
    warmup_iters, timed_iters = 50, 200
    with torch.no_grad():
        for _ in range(warmup_iters):
            model(dummy_input)
        start_t = perf_counter()
        for _ in range(timed_iters):
            model(dummy_input)
        elapsed = perf_counter() - start_t
    print(f"Speed: {elapsed / timed_iters * 1000:.2f} ms")