MNN/source/core/ConvolutionCommon.hpp

63 lines
2.3 KiB
C++
Raw Normal View History

//
// ConvolutionCommon.hpp
// MNN
//
2020-11-05 16:41:56 +08:00
// Created by MNN on 2020/03/02.
// Copyright © 2018, Alibaba Group Holding Limited
//
2020-11-05 16:41:56 +08:00
#ifndef ConvolutionCommon_hpp
#define ConvolutionCommon_hpp
#include "AutoStorage.h"
2020-11-05 16:41:56 +08:00
#include "Execution.hpp"
#include "MNN_generated.h"
namespace MNN {
// Static helpers shared by convolution code paths: decoding of quantized
// (IDSTQuan) weights, extraction of conv weights/scales/bias from a
// Convolution2D, padding computation, and the parameter block for im2col.
// Every function declared here is static; the class declares no data
// members of its own.
class MNN_PUBLIC ConvolutionCommon : public Execution {
public:
    // Decoded form of a quantized weight blob, as returned by load().
    struct Int8Common {
        // Weights kept in int8 form, when produced (TODO confirm which
        // load() flags populate this vs. weightFloat).
        AutoStorage<int8_t> weight;
        // Float scale data for the int8 weights; presumably per output
        // channel, with extra offset data when `asymmetric` is true —
        // NOTE(review): confirm layout against load().
        AutoStorage<float> alpha;
        // Dequantized float weights, when produced.
        AutoStorage<float> weightFloat;
        // Quantization descriptor these weights were decoded from.
        // Int8Common declares no destructor and never frees it; presumably it
        // points into the model buffer, which must outlive this struct —
        // TODO confirm. Null until assigned.
        const IDSTQuan* quan = nullptr;
        // Whether the quantization is asymmetric. Defaults to false so the
        // flag is never read uninitialized.
        bool asymmetric = false;
        // Value lookup table and its reverse mapping. Semantics are not
        // visible from this header — TODO confirm against load().
        std::vector<int8_t> weightMap;
        std::vector<uint8_t> weightReverseMap;
        // Whether the weights may be consumed in int4 form. Defaults to false.
        bool canUseInt4 = false;
    };

    // Decode the quantized weights described by `quan`.
    // `forceFloat` / `forceInt8` presumably request float or int8 output
    // respectively — NOTE(review): confirm precedence when both are set.
    // Returns the decoded data with shared ownership; whether a null pointer
    // signals failure is not visible here — TODO confirm.
    static std::shared_ptr<Int8Common> load(const IDSTQuan* quan, bool forceFloat = false, bool forceInt8 = false);

    // Resolve the float weights of `conv2d`.
    // Out-params: *originWeight receives a pointer to the weight data and
    // *originWeightSize its element count. *quanCommon presumably receives the
    // decoded data when conv2d stores quantized weights (and then owns the
    // memory *originWeight points into) — TODO confirm.
    static void getConvParameters(std::shared_ptr<ConvolutionCommon::Int8Common> *quanCommon, const MNN::Convolution2D *conv2d, const float** originWeight, int* originWeightSize);

    // Extract int8 weights, scales and int32 bias from `conv2d` into the
    // reference out-params. `quanCommon` presumably keeps any decoded storage
    // alive for the returned pointers. Returns true on success (presumably;
    // confirm the meaning of false against the implementation).
    static bool getConvInt8Parameters(const MNN::Convolution2D* conv2d, std::shared_ptr<Int8Common>& quanCommon,
                                      const int8_t*& weight, int& weightSize, float*& scale, int32_t*& bias);

    // Return padX, padY
    static std::pair<int, int> convolutionPad(const Tensor* input, const Tensor* output,
                                              const Convolution2DCommon* common);

    // Return padLeft, padTop, padRight, padBottom
    static std::tuple<int, int, int, int> convolutionPadFull(const Tensor* input, const Tensor* output,
                                                             const Convolution2DCommon* common);

    // Padding for a transposed convolution; presumably returns padX, padY in
    // the same order as convolutionPad — TODO confirm.
    static std::pair<int, int> convolutionTransposePad(const Tensor* input, const Tensor* output,
                                                       const Convolution2DCommon* common);

    // Parameter block describing one im2col transform. Fields are
    // intentionally left without default initializers so the struct stays an
    // aggregate under C++11; callers must set every field they read.
    struct Im2ColParameter {
        // Padding, dilation, stride and kernel extent, per axis.
        int32_t padX;
        int32_t padY;
        int32_t dilateX;
        int32_t dilateY;
        int32_t strideX;
        int32_t strideY;
        int32_t kernelX;
        int32_t kernelY;
        // Input channels in units of 4 (rounding presumably up — TODO confirm).
        int32_t icDiv4;
        int32_t kernelCountUnit;
        // Input / output spatial sizes.
        int32_t iw;
        int32_t ih;
        int32_t ow;
        int32_t oh;
        // Source strides between channel planes (Z) and rows (Y).
        int32_t srcZStep;
        int32_t srcYStep;
        // Channel packing unit and destination input-channel stride.
        int32_t packCUnit;
        int32_t destICStride;
    };
};
} // namespace MNN
#endif