//
// quantizeWeight.cpp
// MNN
//
// Created by MNN on 2019/04/21.
// Copyright © 2018, Alibaba Group Holding Limited
//

#include "quantizeWeight.hpp"
#include <math.h>
#include <algorithm>
#include <cmath>
#include "logkit.h"
#include <MNN/MNNDefine.h>

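// Initialize the per-kernel scale `alpha`: when a clamp value is given (> 1), the scale is
// derived from the kernel's maximum |w|; otherwise the mean |w| of the kernel is used.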
void InitAlpha(const float* weight, const int weightNum, const int kernelNum, float* alpha, const float weightClampValue) {
    const int kernelDim = weightNum / kernelNum;

    for (int i = 0; i < kernelNum; i++) {
        float avg = 0;
        float max = 0;
        float absVal;

        for (int j = 0; j < kernelDim; j++) {
            absVal = std::fabs(weight[i * kernelDim + j]);
            avg += absVal;
            if (absVal > max) {
                max = absVal;
            }
        }
        avg = avg / float(kernelDim);

        if (weightClampValue > 1) {
            alpha[i] = max / (weightClampValue * 1.25);
        } else {
            alpha[i] = avg;
        }
    }
}

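// Quantize every weight with the current per-kernel scale and clamp the result to
// [-weightClampValue, weightClampValue].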
void UpdateQuantizedWeights(const float* weight, const int weightNum, const int kernelNum, float* alpha,
                            const float weightClampValue, int8_t* quantizedWeight) {
    const int kernelDim = weightNum / kernelNum;
    const float eps = 1e-9f;
    float weightQuan;
    CHECK((int)weightClampValue >= 7) << "quantization bits less than 4 not supported yet.";

    for (int i = 0; i < weightNum; i++) {
        weightQuan = weight[i] / (alpha[i / kernelDim] + eps);
        quantizedWeight[i] = std::min(weightClampValue, std::max(-weightClampValue, std::roundf(weightQuan)));
    }
}

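// Refit each kernel's scale as the least-squares fit between the float weights and their
// quantized counterparts: alpha = <w, q> / <q, q>.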
void UpdateAlpha(const float* weight, const int weightNum, const int kernelNum, float* alpha, int8_t* quantizedWeight) {
    const int kernelDim = weightNum / kernelNum;
    const float eps = 1e-9f;

    for (int i = 0; i < kernelNum; i++) {
        const int offset = i * kernelDim;
        float sum1 = 0;
        float sum2 = 0;

        for (int j = 0; j < kernelDim; j++) {
            sum1 += weight[offset + j] * quantizedWeight[offset + j];
            sum2 += quantizedWeight[offset + j] * quantizedWeight[offset + j];
        }
        alpha[i] = sum1 / (sum2 + eps);
    }
}

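// ADMM-style weight quantization: after initializing the scales, alternate between
// re-quantizing the weights with the current scales and refitting the scales to the
// quantized weights for a fixed number of iterations. The reconstructed weight is
// alpha[kernel] * quantizedWeight, which is what the final diff log compares against.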
// weight format is [co, ci, kh, kw]
int QuantizeWeightADMM(const float* weight, const int weightNum, int8_t* quantizedWeight, float* alpha,
                       const int kernelNum, const float weightClampValue) {
    // channels: co
    DCHECK((weightNum % kernelNum) == 0) << "weight size error!";
    const int kernelDim = weightNum / kernelNum; // ci * kh * kw

    InitAlpha(weight, weightNum, kernelNum, alpha, weightClampValue);

    int iter = 0;
    float diffRate = 1;
    float preSum = 0;
    float curSum = 0;
    const int maxIter = 1000;

    for (int i = 0; i < weightNum; i++) {
        preSum += std::fabs(weight[i]);
    }
    // update weights quan
    while (iter < maxIter) {
        UpdateQuantizedWeights(weight, weightNum, kernelNum, alpha, weightClampValue, quantizedWeight);
        UpdateAlpha(weight, weightNum, kernelNum, alpha, quantizedWeight);
        iter++;
    }

    for (int i = 0; i < weightNum; i++) {
        curSum += std::fabs(quantizedWeight[i] * alpha[i / kernelDim]);
    }
    DLOG(INFO) << "iter: " << iter << " with diff " << preSum - curSum;
    return 0;
}

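// Per-channel MAX_ABS quantization: for each output channel, scale = max|w| / clamp value,
// then each weight is rounded and clamped to [-clamp, clamp].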
// weight format is [co, ci, kh, kw]
int SymmetricQuantizeWeight(const float* weight, const int size, int8_t* quantizedWeight, float* scale,
                            const int channels, float weightClampValue) {
    DCHECK((size % channels) == 0) << "weight size error!";
    const int channelStride = size / channels;
    const int quantizedMaxValue = weightClampValue;

    for (int c = 0; c < channels; ++c) {
        const auto weightChannelStart = weight + c * channelStride;
        auto quantizedWeightChannelStart = quantizedWeight + c * channelStride;
        auto minmaxValue = std::minmax_element(weightChannelStart, weightChannelStart + channelStride);
        const float dataAbsMax = std::fmax(std::fabs(*minmaxValue.first), std::fabs(*minmaxValue.second));

        float scaleDataToInt8 = 1.0f;
        if (dataAbsMax == 0) {
            scale[c] = 0.0f;
        } else {
            scale[c] = dataAbsMax / quantizedMaxValue;
            scaleDataToInt8 = quantizedMaxValue / dataAbsMax;
        }

        for (int i = 0; i < channelStride; ++i) {
            const int32_t quantizedInt8Value = static_cast<int32_t>(roundf(weightChannelStart[i] * scaleDataToInt8));
            quantizedWeightChannelStart[i] =
                std::min(quantizedMaxValue, std::max(-quantizedMaxValue, quantizedInt8Value));
        }
    }

    return 0;
}

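// Quantize a regular convolution per output channel. Weights are quantized with the chosen
// method ("MAX_ABS" or "ADMM"); the returned scale folds the input scale, per-channel weight
// scale, and output scale together so int8 accumulators can be rescaled to the output range.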
int QuantizeConvPerChannel(const float* weight, const int size, const float* bias, int8_t* quantizedWeight,
                           int32_t* quantizedBias, float* scale, const float inputScale, const float outputScale,
                           const int inputChannel, const int outputChannel, std::string method, float weightClampValue, bool mergeChannel) {
    const int icXoc = inputChannel * outputChannel;
    DCHECK(size % icXoc == 0) << "Input Data Size Error!";

    std::vector<float> quantizedWeightScale(outputChannel);

    float inputScalexWeight = 1.0f;
    if (mergeChannel) {
        if (method == "MAX_ABS") {
            SymmetricQuantizeWeight(weight, size, quantizedWeight, quantizedWeightScale.data(), outputChannel, weightClampValue);
        } else if (method == "ADMM") {
            QuantizeWeightADMM(weight, size, quantizedWeight, quantizedWeightScale.data(), outputChannel, weightClampValue);
        }
        inputScalexWeight = inputScale;
    } else {
        const int kernelSize = size / icXoc;
        const int ocStride = size / outputChannel;

        std::vector<float> weightMultiByInputScale(size);
        for (int oc = 0; oc < outputChannel; ++oc) {
            for (int ic = 0; ic < inputChannel; ++ic) {
                for (int i = 0; i < kernelSize; ++i) {
                    const int index = oc * ocStride + ic * kernelSize + i;
                    weightMultiByInputScale[index] = inputScale * weight[index];
                }
            }
        }
        if (method == "MAX_ABS") {
            SymmetricQuantizeWeight(weightMultiByInputScale.data(), size, quantizedWeight, quantizedWeightScale.data(), outputChannel, weightClampValue);
        } else if (method == "ADMM") {
            QuantizeWeightADMM(weightMultiByInputScale.data(), size, quantizedWeight, quantizedWeightScale.data(), outputChannel, weightClampValue);
        }
    }

    for (int i = 0; i < outputChannel; ++i) {
        if (fabs(outputScale) <= 1e-6) {
            scale[i] = 0.0f;
        } else {
            scale[i] = inputScalexWeight * quantizedWeightScale[i] / outputScale;
        }
    }

    if (bias) {
        for (int i = 0; i < outputChannel; ++i) {
            if (fabs(inputScalexWeight) <= 1e-6 || fabs(quantizedWeightScale[i]) <= 1e-6) {
                quantizedBias[i] = 0;
            } else {
                quantizedBias[i] = static_cast<int32_t>(bias[i] / (inputScalexWeight * quantizedWeightScale[i]));
            }
        }
    }

    return 0;
}

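// Quantize a depthwise convolution; input and output channel counts must match, and each
// channel gets its own weight scale.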
int QuantizeDepthwiseConv(const float* weight, const int size, const float* bias, int8_t* quantizedWeight,
                          int32_t* quantizedBias, float* scale, const float inputScale, const float outputScale,
                          const int inputChannel, const int outputChannel, std::string method, float weightClampValue, bool mergeChannel) {
    DCHECK(inputChannel == outputChannel) << "Input Data Size Error!";

    std::vector<float> quantizedWeightScale(inputChannel);
    if (method == "MAX_ABS") {
        SymmetricQuantizeWeight(weight, size, quantizedWeight, quantizedWeightScale.data(), inputChannel, weightClampValue);
    } else if (method == "ADMM") {
        QuantizeWeightADMM(weight, size, quantizedWeight, quantizedWeightScale.data(), inputChannel, weightClampValue);
    }

    for (int c = 0; c < inputChannel; ++c) {
        const int index = c;
        if (fabs(outputScale) <= 1e-6) {
            scale[index] = 0.0f;
        } else {
            scale[index] = inputScale * quantizedWeightScale[c] / outputScale;
        }
    }

    if (bias) {
        for (int i = 0; i < outputChannel; ++i) {
            if (fabs(inputScale) <= 1e-6 || fabs(quantizedWeightScale[i]) <= 1e-6) {
                quantizedBias[i] = 0;
            } else {
                quantizedBias[i] = static_cast<int32_t>(bias[i] / (inputScale * quantizedWeightScale[i]));
            }
        }
    }

    return 0;
}