mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			108 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			108 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			C++
		
	
	
	
| //
 | |
| //  OnnxClip.cpp
 | |
| //  MNNConverter
 | |
| //
 | |
| //  Created by MNN on 2020/06/20.
 | |
| //  Copyright © 2018, Alibaba Group Holding Limited
 | |
| //
 | |
| 
 | |
| #include <limits>
 | |
| #include "MNN_generated.h"
 | |
| #include "OnnxExtraManager.hpp"
 | |
| namespace MNN {
 | |
| namespace Express {
 | |
| 
 | |
| template<typename T> 
 | |
| static EXPRP clipConvert(EXPRP expr, bool supportRelu6) {
 | |
|     auto inputs     = expr->inputs();
 | |
|     auto op         = expr->get();
 | |
|     auto extraParam = op->main_as_Extra();
 | |
|     // auto dataType = expr->outputInfo(0)->type.code;
 | |
|     auto maxValue  = std::numeric_limits<T>().max();
 | |
|     auto minValue  = std::numeric_limits<T>().lowest();
 | |
|     if (nullptr != extraParam->attr()) {
 | |
|         const int attrSize = extraParam->attr()->size();
 | |
|         for (int i = 0; i < attrSize; ++i) {
 | |
|             auto attr       = extraParam->attr()->GetAs<Attribute>(i);
 | |
|             const auto& key = attr->key()->str();
 | |
|             if (key == "max") {
 | |
|                 maxValue = attr->f();
 | |
|             } else if (key == "min") {
 | |
|                 minValue = attr->f();
 | |
|             }
 | |
|         }
 | |
|     }
 | |
|     bool unknown_min_max = false;
 | |
|     if (inputs.size() == 2 || (inputs.size() == 3 && inputs[1].get() != nullptr)) {
 | |
|         auto minPtr = inputs[1]->readMap<T>();
 | |
|         if (nullptr == minPtr) {
 | |
|             unknown_min_max = true;
 | |
|         } else {
 | |
|             minValue = minPtr[0];
 | |
|         }
 | |
|     }
 | |
|     if (inputs.size() == 3 && !unknown_min_max) {
 | |
|         auto maxPtr = inputs[2]->readMap<T>();
 | |
|         if (nullptr == maxPtr) {
 | |
|             unknown_min_max = true;
 | |
|         } else {
 | |
|             maxValue = maxPtr[0];
 | |
|         }
 | |
|     }
 | |
|     if (unknown_min_max || (!supportRelu6)) {
 | |
|         auto minVar = _Scalar<T>(minValue);
 | |
|         auto maxVar = _Scalar<T>(maxValue);
 | |
|         if (inputs.size() >= 2 && inputs[1].get() != nullptr) {
 | |
|             minVar = inputs[1];
 | |
|         }
 | |
|         if (inputs.size() >= 3) {
 | |
|             maxVar = inputs[2];
 | |
|         }
 | |
|         auto res = _Minimum(_Maximum(inputs[0], minVar), maxVar);
 | |
|         auto newExpr = res->expr().first;
 | |
|         newExpr->setName(expr->name());
 | |
|         return newExpr;
 | |
|     }
 | |
|     if(maxValue > std::numeric_limits<T>::max()) {
 | |
|         maxValue = std::numeric_limits<T>().max();
 | |
|     }
 | |
|     if(minValue < std::numeric_limits<T>::lowest()) {
 | |
|         minValue = std::numeric_limits<T>().lowest();
 | |
|     }
 | |
|     std::unique_ptr<OpT> newOp(new OpT);
 | |
|     newOp->type                     = OpType_ReLU6;
 | |
|     newOp->main.type                = OpParameter_Relu6;
 | |
|     newOp->main.value               = new Relu6T;
 | |
|     newOp->main.AsRelu6()->maxValue = maxValue;
 | |
|     newOp->main.AsRelu6()->minValue = minValue;
 | |
|     auto res = Expr::create(newOp.get(), {inputs[0]});
 | |
|     res->setName(expr->name());
 | |
|     return res;
 | |
| }
 | |
| 
 | |
| class OnnxClipTransform : public OnnxExtraManager::Transform {
 | |
| public:
 | |
|     virtual EXPRP onExecute(EXPRP expr) const override {
 | |
|         auto inputs = expr->inputs();
 | |
|         halide_type_code_t type = halide_type_int;
 | |
|         for (int i = 0; i < inputs.size(); ++i) {
 | |
|             if (nullptr != inputs[i] && nullptr != inputs[i]->getInfo()) {
 | |
|                 type = static_cast<halide_type_code_t>(inputs[i]->getInfo()->type.code);
 | |
|                 break;
 | |
|             }
 | |
|         }
 | |
|         if (type == halide_type_float || inputs.size() == 1) {
 | |
|             return clipConvert<float>(expr, true);
 | |
|         }
 | |
|         return clipConvert<int32_t>(expr, false);
 | |
|     }
 | |
| };
 | |
| 
 | |
| static auto gRegister = []() {
 | |
|     OnnxExtraManager::get()->insert("Clip", std::shared_ptr<OnnxExtraManager::Transform>(new OnnxClipTransform));
 | |
|     return true;
 | |
| }();
 | |
| 
 | |
| } // namespace Express
 | |
| } // namespace MNN
 |