mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
	
	
		
			211 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			C++
		
	
	
	
		
		
			
		
	
	
			211 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			C++
		
	
	
	
|  | //
 | ||
|  | //  Initializer.cpp
 | ||
|  | //  MNN
 | ||
|  | //
 | ||
|  | //  Created by MNN on 2019/11/28.
 | ||
|  | //  Copyright © 2018, Alibaba Group Holding Limited
 | ||
|  | //
 | ||
|  | 
 | ||
|  | #include "Initializer.hpp"
 | ||
|  | #include <MNN/expr/ExprCreator.hpp>
 | ||
|  | #include <cmath>
 | ||
|  | #include <vector>
 | ||
|  | #include "Distributions.hpp"
 | ||
|  | #include "RandomGenerator.hpp"
 | ||
|  | 
 | ||
|  | namespace MNN { | ||
|  | namespace Express { | ||
|  | 
 | ||
|  | Express::VARP Initializer::createConstVar(Express::INTS dim, Express::Dimensionformat format) { | ||
|  |     auto res = Express::_Input(dim, format, halide_type_of<float>()); | ||
|  |     this->onExecute(res); | ||
|  |     res.fix(Express::VARP::CONSTANT); | ||
|  |     return res; | ||
|  | } | ||
|  | 
 | ||
|  | class ConstantInitializer : public Initializer { | ||
|  | public: | ||
|  |     ConstantInitializer(float value) : mConstant(value) { | ||
|  |     } | ||
|  | 
 | ||
|  |     virtual void onExecute(Express::VARP p) override { | ||
|  |         const int count = p->getInfo()->size; | ||
|  |         MNN_ASSERT(count > 0); | ||
|  |         auto ptr = p->writeMap<float>(); | ||
|  |         for (int i = 0; i < count; i++) { | ||
|  |             ptr[i] = mConstant; | ||
|  |         } | ||
|  |     } | ||
|  | 
 | ||
|  | private: | ||
|  |     float mConstant; | ||
|  | }; | ||
|  | Initializer* Initializer::constValue(float value) { | ||
|  |     return new ConstantInitializer(value); | ||
|  | } | ||
|  | 
 | ||
|  | class UniformInitializer : public Initializer { | ||
|  | public: | ||
|  |     UniformInitializer(float min = 0, float max = 1) { | ||
|  |         mMin = min; | ||
|  |         mMax = max; | ||
|  |     } | ||
|  | 
 | ||
|  |     virtual void onExecute(Express::VARP p) override { | ||
|  |         const int count = p->getInfo()->size; | ||
|  |         MNN_ASSERT(count > 0); | ||
|  |         Distributions::uniform(count, mMin, mMax, p->writeMap<float>(), RandomGenerator::generator()); | ||
|  |     } | ||
|  | 
 | ||
|  | private: | ||
|  |     float mMin; | ||
|  |     float mMax; | ||
|  | }; | ||
|  | 
 | ||
|  | Initializer* Initializer::uniform(float minValue, float maxValue) { | ||
|  |     return new UniformInitializer(minValue, maxValue); | ||
|  | } | ||
|  | 
 | ||
|  | class XavierInitializer : public Initializer { | ||
|  | public: | ||
|  |     XavierInitializer(VarianceNorm norm = FANIN) { | ||
|  |         mNorm = norm; | ||
|  |     } | ||
|  | 
 | ||
|  |     virtual void onExecute(Express::VARP p) override { | ||
|  |         const int count = p->getInfo()->size; | ||
|  |         MNN_ASSERT(count > 0); | ||
|  |         const std::vector<int> dims = p->getInfo()->dim; | ||
|  |         // referenced from Caffe
 | ||
|  |         // https://github.com/BVLC/caffe/blob/master/include/caffe/filler.hpp
 | ||
|  |         int fanIn  = count / dims[0]; | ||
|  |         int fanOut = dims.size() > 1 ? count / dims[1] : count; | ||
|  |         float n    = fanIn; // default: FANIN
 | ||
|  |         if (mNorm == VarianceNorm::AVERAGE) { | ||
|  |             n = (fanIn + fanOut) / 2.0f; | ||
|  |         } else if (mNorm == VarianceNorm::FANOUT) { | ||
|  |             n = fanOut; | ||
|  |         } | ||
|  |         float scale = sqrtf(3.0f / n); | ||
|  | 
 | ||
|  |         Distributions::uniform(count, -scale, scale, p->writeMap<float>(), RandomGenerator::generator()); | ||
|  |     } | ||
|  | 
 | ||
|  | private: | ||
|  |     VarianceNorm mNorm; | ||
|  | }; | ||
|  | Initializer* Initializer::xavier(VarianceNorm norm) { | ||
|  |     return new XavierInitializer(norm); | ||
|  | } | ||
|  | 
 | ||
|  | class GaussianInitializer : public Initializer { | ||
|  | public: | ||
|  |     GaussianInitializer(float mean = 0, float std = 1) { | ||
|  |         mMean = mean; | ||
|  |         mStd  = std; | ||
|  |     } | ||
|  | 
 | ||
|  |     virtual void onExecute(Express::VARP p) override { | ||
|  |         const int count = p->getInfo()->size; | ||
|  |         MNN_ASSERT(count > 0); | ||
|  |         Distributions::gaussian(count, mMean, mStd, p->writeMap<float>(), RandomGenerator::generator()); | ||
|  |     } | ||
|  | 
 | ||
|  | private: | ||
|  |     float mMean; | ||
|  |     float mStd; | ||
|  | }; | ||
|  | Initializer* Initializer::gauss(float mean, float std) { | ||
|  |     return new GaussianInitializer(mean, std); | ||
|  | } | ||
|  | 
 | ||
|  | class MSRAInitializer : public Initializer { | ||
|  | public: | ||
|  |     MSRAInitializer(VarianceNorm norm = FANIN) { | ||
|  |         mNorm = norm; | ||
|  |     } | ||
|  | 
 | ||
|  |     virtual void onExecute(Express::VARP p) override { | ||
|  |         const int count = p->getInfo()->size; | ||
|  |         MNN_ASSERT(count > 0); | ||
|  |         const std::vector<int> dims = p->getInfo()->dim; | ||
|  |         // referenced from Caffe
 | ||
|  |         // https://github.com/BVLC/caffe/blob/master/include/caffe/filler.hpp
 | ||
|  |         int fanIn  = count / dims[0]; | ||
|  |         int fanOut = dims.size() > 1 ? count / dims[1] : count; | ||
|  |         float n    = fanIn; // default: FANIN
 | ||
|  |         if (mNorm == VarianceNorm::AVERAGE) { | ||
|  |             n = (fanIn + fanOut) / 2.0f; | ||
|  |         } else if (mNorm == VarianceNorm::FANOUT) { | ||
|  |             n = fanOut; | ||
|  |         } | ||
|  |         float std = sqrtf(2.0f / n); | ||
|  | 
 | ||
|  |         Distributions::gaussian(count, 0.0f, std, p->writeMap<float>(), RandomGenerator::generator()); | ||
|  |     } | ||
|  | 
 | ||
|  | private: | ||
|  |     VarianceNorm mNorm; | ||
|  | }; | ||
|  | Initializer* Initializer::MSRA(VarianceNorm norm) { | ||
|  |     return new MSRAInitializer(norm); | ||
|  | } | ||
|  | 
 | ||
|  | class BilinearInitializer : public Initializer { | ||
|  | public: | ||
|  |     BilinearInitializer() = default; | ||
|  | 
 | ||
|  |     virtual void onExecute(Express::VARP p) override { | ||
|  |         const int count = p->getInfo()->size; | ||
|  |         MNN_ASSERT(count > 0); | ||
|  |         const std::vector<int> dims = p->getInfo()->dim; | ||
|  |         MNN_ASSERT(dims.size() == 4); | ||
|  |         MNN_ASSERT(dims[2] == dims[3]); // NCHW, H == W
 | ||
|  |         // referenced from Caffe
 | ||
|  |         // https://github.com/BVLC/caffe/blob/master/include/caffe/filler.hpp
 | ||
|  |         int f   = ceilf(dims[3] / 2.0f); | ||
|  |         float c = (dims[3] - 1) / (2.0f * f); | ||
|  |         auto ptr = p->writeMap<float>(); | ||
|  | 
 | ||
|  |         for (int i = 0; i < count; i++) { | ||
|  |             float x                 = i % dims[3]; | ||
|  |             float y                 = (i / dims[3]) % dims[2]; | ||
|  |             ptr[i] = (1 - std::fabs(x / f - c)) * (1 - std::fabs(y / f - c)); | ||
|  |         } | ||
|  |     } | ||
|  | }; | ||
|  | Initializer* Initializer::bilinear() { | ||
|  |     return new BilinearInitializer(); | ||
|  | } | ||
|  | 
 | ||
|  | class PositiveUnitball : public Initializer { | ||
|  | public: | ||
|  |     PositiveUnitball() = default; | ||
|  | 
 | ||
|  |     virtual void onExecute(Express::VARP p) override { | ||
|  |         const int count = p->getInfo()->size; | ||
|  |         MNN_ASSERT(count > 0); | ||
|  |         const std::vector<int> dims = p->getInfo()->dim; | ||
|  |         auto ptr = p->writeMap<float>(); | ||
|  | 
 | ||
|  |         Distributions::uniform(count, 0, 1, ptr, RandomGenerator::generator()); | ||
|  | 
 | ||
|  |         int dim = count / dims[0]; | ||
|  |         for (int i = 0; i < dims[0]; i++) { | ||
|  |             float sum = 0; | ||
|  |             for (int j = 0; j < dim; j++) { | ||
|  |                 sum += ptr[i * dim + j]; | ||
|  |             } | ||
|  |             for (int j = 0; j < dim; j++) { | ||
|  |                 ptr[i * dim + j] = ptr[i * dim + j] / sum; | ||
|  |             } | ||
|  |         } | ||
|  |     } | ||
|  | }; | ||
|  | Initializer* Initializer::positiveUnitball() { | ||
|  |     return new PositiveUnitball(); | ||
|  | } | ||
|  | 
 | ||
|  | } // namespace Express
 | ||
|  | } // namespace MNN
 |