Created
May 16, 2021 10:15
-
-
Save AmirOfir/82882cc236745be2d4ea61900ded97c4 to your computer and use it in GitHub Desktop.
An implementation of MobileNet v1 with libtorch c++
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include <algorithm>
#include <stdexcept>

#include <torch/torch.h>
// Special thanks to wangvation: https://github.com/wangvation/torch-mobilenet/blob/master/module/mobilenet.py
//# BSD 2 - Clause License | |
// | |
//# Copyright(c) 2019 wangvation.All rights reserved. | |
// | |
//# Redistribution and use in source and binary forms, with or without | |
//# modification, are permitted provided that the following conditions are met : | |
// | |
//# 1. Redistributions of source code must retain the above copyright notice, | |
//# this list of conditions and the following disclaimer.
// | |
//# 2. Redistributions in binary form must reproduce the above copyright notice, | |
//# this list of conditions and the following disclaimer in the documentation
//# and/or other materials provided with the distribution.
// | |
//# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
//# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
//# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |
//# ARE DISCLAIMED.IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | |
//# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, | |
//# OR CONSEQUENTIAL DAMAGES(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |
// # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |
// # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |
// # CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING NEGLIGENCE OR OTHERWISE) | |
// # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |
// # POSSIBILITY OF SUCH DAMAGE. | |
//# ============================================================================ | |
/// Rounds a channel count to a multiple of `divisor`.
///
/// Ported from the TensorFlow slim MobileNet helper:
/// https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
///
/// @param v         Requested channel count.
/// @param divisor   Value the result must be divisible by; must be positive.
/// @param minValue  Lower bound for the result. The sentinel -1 means
///                  "use `divisor` as the lower bound".
/// @return The adjusted channel count. The value is always integral; the
///         `float` return type is kept only for compatibility with existing
///         callers.
/// @throws std::invalid_argument if `divisor` is not positive (the rounding
///         below would otherwise divide by zero).
///
/// Declared `inline` because this definition lives in a header: without it,
/// including the header from more than one translation unit produces
/// multiple-definition link errors.
inline float makeDivisable(int v, int divisor, int minValue = -1)
{
    if (divisor <= 0)
        throw std::invalid_argument("makeDivisable: divisor must be positive");

    if (minValue == -1)
        minValue = divisor;

    // Integer arithmetic throughout: add half a divisor, then truncate down to
    // a multiple of divisor. For non-negative v this rounds to the nearest
    // multiple.
    int new_v = std::max(minValue, (v + divisor / 2) / divisor * divisor);

    // Make sure that rounding down does not shrink the value by more than 10%.
    if (new_v < 0.9 * v)
        new_v += divisor;

    return static_cast<float>(new_v);
}
class DepthSepConvImpl : public torch::nn::Cloneable<DepthSepConvImpl> | |
{ | |
int _in_channels; | |
int _out_channels; | |
torch::nn::Conv2d _depthwise_conv; | |
torch::nn::BatchNorm2d _bn1; | |
torch::nn::Conv2d _pointwise_conv; | |
torch::nn::BatchNorm2d _bn2; | |
public: | |
DepthSepConvImpl(int in_channels, int out_channels, int kernel_dim, int stride, int padding, int multiplier=1) | |
: _in_channels(makeDivisable(in_channels * multiplier, 8)), | |
_out_channels(makeDivisable(out_channels* multiplier, 8)), | |
_depthwise_conv(register_module("depthwise_conv", | |
torch::nn::Conv2d( | |
torch::nn::Conv2dOptions(_in_channels, _in_channels, kernel_dim).stride(stride).padding(padding).groups(in_channels).bias(false)))), | |
_bn1(register_module("bn1", | |
torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(_in_channels)))), | |
_pointwise_conv(register_module("pointwise_conv", | |
torch::nn::Conv2d(torch::nn::Conv2dOptions(_in_channels, _out_channels, 1).stride(1).groups(1).bias(false)))), | |
_bn2(register_module("bn2", | |
torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(_out_channels)))) | |
{ | |
} | |
torch::Tensor forward(const torch::Tensor& input) | |
{ | |
auto x = _depthwise_conv(input); | |
x = _bn1(x); | |
x = relu(x); | |
x = _pointwise_conv(x); | |
x = _bn2(x); | |
x = relu(x); | |
return x; | |
} | |
void reset() { } | |
virtual void train(bool on = true) override | |
{ | |
} | |
}; | |
TORCH_MODULE(DepthSepConv); | |
class MobileNetV1Impl : public torch::nn::Cloneable<MobileNetV1Impl> | |
{ | |
/* | |
docstring for MobileNetV1 | |
MobileNetV1 Body Architecture | |
| Type / Stride | Filter Shape | Input Size | Output Size | | |
| : ------------ | : ------------------ | : ------------ - | : ------------ - | | |
| Conv / s2 | 3 × 3 × 3 × 32 | 224 x 224 x 3 | 112 x 112 x 32 | | |
| Conv dw / s1 | 3 × 3 × 32 dw | 112 x 112 x 32 | 112 x 112 x 32 | | |
| Conv / s1 | 1 × 1 × 32 x 64 | 112 x 112 x 32 | 112 x 112 x 64 | | |
| Conv dw / s2 | 3 × 3 × 64 dw | 112 x 112 x 64 | 56 x 56 x 64 | | |
| Conv / s1 | 1 × 1 × 64 × 128 | 56 x 56 x 64 | 56 x 56 x 128 | | |
| Conv dw / s1 | 3 × 3 × 128 dw | 56 x 56 x 128 | 56 x 56 x 128 | | |
| Conv / s1 | 1 × 1 × 128 × 128 | 56 x 56 x 128 | 56 x 56 x 128 | | |
| Conv dw / s2 | 3 × 3 × 128 dw | 56 x 56 x 128 | 28 x 28 x 128 | | |
| Conv / s1 | 1 × 1 × 128 × 256 | 28 x 28 x 128 | 28 x 28 x 256 | | |
| Conv dw / s1 | 3 × 3 × 256 dw | 28 x 28 x 256 | 28 x 28 x 256 | | |
| Conv / s1 | 1 × 1 × 256 × 256 | 28 x 28 x 256 | 28 x 28 x 256 | | |
| Conv dw / s2 | 3 × 3 × 256 dw | 28 x 28 x 256 | 14 x 14 x 256 | | |
| Conv / s1 | 1 × 1 × 256 × 512 | 14 x 14 x 256 | 14 x 14 x 512 | | |
| Conv dw / s1 | 3 × 3 × 512 dw | 14 x 14 x 512 | 14 x 14 x 512 | | |
| Conv / s1 | 1 × 1 × 512 × 512 | 14 x 14 x 512 | 14 x 14 x 512 | | |
| Conv dw / s1 | 3 × 3 × 512 dw | 14 x 14 x 512 | 14 x 14 x 512 | | |
| Conv / s1 | 1 × 1 × 512 × 512 | 14 x 14 x 512 | 14 x 14 x 512 | | |
| Conv dw / s1 | 3 × 3 × 512 dw | 14 x 14 x 512 | 14 x 14 x 512 | | |
| Conv / s1 | 1 × 1 × 512 × 512 | 14 x 14 x 512 | 14 x 14 x 512 | | |
| Conv dw / s1 | 3 × 3 × 512 dw | 14 x 14 x 512 | 14 x 14 x 512 | | |
| Conv / s1 | 1 × 1 × 512 × 512 | 14 x 14 x 512 | 14 x 14 x 512 | | |
| Conv dw / s1 | 3 × 3 × 512 dw | 14 x 14 x 512 | 14 x 14 x 512 | | |
| Conv / s1 | 1 × 1 × 512 × 512 | 14 x 14 x 512 | 14 x 14 x 512 | | |
| Conv dw / s2 | 3 × 3 × 512 dw | 14 x 14 x 512 | 7 x 7 x 512 | | |
| Conv / s1 | 1 × 1 × 512 × 1024 | 7 x 7 x 512 | 7 x 7 x 1024 | | |
| Conv dw / s1 | 3 × 3 × 1024 dw | 7 x 7 x 1024 | 7 x 7 x 1024 | | |
| Conv / s1 | 1 × 1 × 1024 × 1024 | 7 x 7 x 1024 | 7 x 7 x 1024 | | |
| AvgPool / s1 | Pool 7 × 7 | 7 x 7 x 1024 | 1 x 1 x 1024 | | |
| FC / s1 | 1024 x 1000 | 1 x 1 x 1024 | 1 x 1 x 1000 | | |
| Softmax / s1 | Classifier | 1 x 1 x 1000 | 1 x 1 x 1000 | | |
*/ | |
int first_in_channel; | |
int last_out_channel; | |
int num_classes; | |
torch::nn::Sequential features; | |
torch::nn::Sequential classifier; | |
public: | |
MobileNetV1Impl(int resolution = 224, int num_classes = 1000, int multiplier = 1) | |
: | |
first_in_channel(makeDivisable(32 * multiplier, 8)), | |
last_out_channel(makeDivisable(1024 * multiplier, 8)), | |
num_classes(num_classes), | |
features(register_module("features", torch::nn::Sequential( | |
torch::nn::Conv2d(torch::nn::Conv2dOptions(3, first_in_channel, 3).stride(2).padding(1)), | |
DepthSepConv(32, 64, 3, 1,1, multiplier), | |
DepthSepConv(64, 128, 3, 2,1, multiplier), | |
DepthSepConv(128, 128, 3, 1,1, multiplier), | |
DepthSepConv(128, 256, 3, 2,1, multiplier), | |
DepthSepConv(256, 256, 3, 1,1, multiplier), | |
DepthSepConv(256, 512, 3, 2,1, multiplier), | |
DepthSepConv(512, 512, 3, 1,1, multiplier), | |
DepthSepConv(512, 512, 3, 1,1, multiplier), | |
DepthSepConv(512, 512, 3, 1,1, multiplier), | |
DepthSepConv(512, 512, 3, 1,1, multiplier), | |
DepthSepConv(512, 512, 3, 1,1, multiplier), | |
DepthSepConv(512, 1024, 3, 2,1, multiplier), | |
DepthSepConv(1024, 1024,3, 1,1, multiplier) | |
))), | |
classifier(register_module("classifier", torch::nn::Sequential( | |
torch::nn::AvgPool2d(torch::nn::AvgPool2dOptions((int)(resolution / 32))), // 7 x 7 x 1024 | |
torch::nn::Conv2d(torch::nn::Conv2dOptions(last_out_channel, num_classes, 1)), // 1 x 1 x 1024 | |
torch::nn::Softmax2d() | |
))) | |
{ | |
assert(resolution % 32 == 0); | |
} | |
torch::Tensor forward(const torch::Tensor& input) | |
{ | |
torch::Tensor x = features->forward(input); | |
x = classifier->forward(x); | |
x = x.view({ -1, num_classes }); | |
return x; | |
} | |
void reset() { } | |
virtual void train(bool on = true) override | |
{ | |
} | |
}; | |
TORCH_MODULE(MobileNetV1); |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment