//
// TensorStatistic.hpp
// MNN
//
// Created by MNN on 2019/06/30.
// Copyright © 2018, Alibaba Group Holding Limited
//

2019-07-25 13:36:35 +08:00
|
|
|
#include <memory>
|
2019-07-11 13:56:52 +08:00
|
|
|
#include <vector>
|
2019-12-27 22:16:57 +08:00
|
|
|
#include <MNN/Tensor.hpp>
|
2019-08-07 16:44:09 +08:00
|
|
|
#include <string>
|
2019-07-25 13:36:35 +08:00
|
|
|
|
|
|
|
|
/**
 * Selects how TensorStatistic derives a threshold.
 *
 * Kept as an unscoped enum with explicit values so that existing code using
 * the unqualified enumerators (and any integer conversions) keeps working.
 */
enum GET_THRESHOLD_METHOD {
    // NOTE(review): presumably a maximum-value based threshold — confirm in the implementation.
    THRESHOLD_MAX = 0,
    // NOTE(review): presumably a KL-divergence based threshold — confirm in the implementation.
    THRESHOLD_KL = 1,
};
|
2019-07-11 13:56:52 +08:00
|
|
|
|
|
|
|
|
class TensorStatistic {
|
|
|
|
|
public:
|
2021-01-06 19:54:08 +08:00
|
|
|
TensorStatistic(const MNN::Tensor* tensor, std::string method, const std::string& name, float featureClampValue, int binNumber = 2048, GET_THRESHOLD_METHOD thresholdMethod = THRESHOLD_KL);
|
2019-07-11 13:56:52 +08:00
|
|
|
~TensorStatistic() {
|
|
|
|
|
// Do nothing
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void resetUpdatedDistributionFlag() {
|
|
|
|
|
mUpdatedDistributionFlag = false;
|
|
|
|
|
}
|
|
|
|
|
void resetUpdatedRangeFlags() {
|
|
|
|
|
mUpdatedRangeFlags = false;
|
|
|
|
|
}
|
|
|
|
|
void updateRange();
|
|
|
|
|
void resetDistribution();
|
|
|
|
|
void updateDistribution();
|
|
|
|
|
|
2019-07-25 13:36:35 +08:00
|
|
|
void setThresholdMethod(GET_THRESHOLD_METHOD thresholdMethod);
|
2019-08-07 16:44:09 +08:00
|
|
|
void setChannelWise(bool mergeChannel);
|
2019-07-25 13:36:35 +08:00
|
|
|
|
2019-07-11 13:56:52 +08:00
|
|
|
std::vector<float> finishAndCompute();
|
|
|
|
|
|
2019-08-07 16:44:09 +08:00
|
|
|
// only this one for ADMM
|
|
|
|
|
std::vector<float> computeScaleADMM();
|
|
|
|
|
|
2021-01-05 20:40:46 +08:00
|
|
|
std::string name() {
|
|
|
|
|
return mName;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool visited() {
|
|
|
|
|
return mVisited;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void setVisited(bool visited) {
|
|
|
|
|
mVisited = visited;
|
|
|
|
|
}
|
|
|
|
|
|
2021-01-07 15:17:23 +08:00
|
|
|
std::pair<std::vector<float>, float> fakeQuantFeature();
|
|
|
|
|
float computeDistance(std::vector<float> fakeQuantedFeature);
|
2021-01-05 20:40:46 +08:00
|
|
|
|
2019-07-11 13:56:52 +08:00
|
|
|
private:
|
|
|
|
|
int _computeThreshold(const std::vector<float>& distribution);
|
|
|
|
|
std::vector<std::pair<float, float>> mRangePerChannel;
|
|
|
|
|
std::vector<float> mIntervals;
|
|
|
|
|
std::vector<bool> mValidChannel;
|
|
|
|
|
std::vector<std::vector<float>> mDistribution;
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<MNN::Tensor> mHostTensor;
|
|
|
|
|
const MNN::Tensor* mOriginTensor;
|
|
|
|
|
int mBinNumber;
|
|
|
|
|
bool mUpdatedDistributionFlag = false;
|
|
|
|
|
bool mUpdatedRangeFlags = false;
|
|
|
|
|
|
2019-07-25 13:36:35 +08:00
|
|
|
bool mMergeChannel = true;
|
2019-08-07 16:44:09 +08:00
|
|
|
std::string mName;
|
2019-07-25 13:36:35 +08:00
|
|
|
GET_THRESHOLD_METHOD mThresholdMethod = THRESHOLD_KL;
|
2021-01-05 20:40:46 +08:00
|
|
|
bool mVisited = false;
|
|
|
|
|
std::vector<float> mScales;
|
2021-01-06 19:54:08 +08:00
|
|
|
float mFeatureClampValue = 127.0f;
|
2019-07-11 13:56:52 +08:00
|
|
|
};
|