2020-07-01 19:06:31 +08:00
|
|
|
//
|
|
|
|
|
// OnnxClip.cpp
|
|
|
|
|
// MNNConverter
|
|
|
|
|
//
|
|
|
|
|
// Created by MNN on 2020/06/20.
|
|
|
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
|
|
|
//
|
|
|
|
|
|
2020-11-05 16:41:56 +08:00
|
|
|
#include <limits>
|
2020-07-01 19:06:31 +08:00
|
|
|
#include "MNN_generated.h"
|
|
|
|
|
#include "OnnxExtraManager.hpp"
|
|
|
|
|
namespace MNN {
|
|
|
|
|
namespace Express {
|
|
|
|
|
|
2023-12-04 11:12:20 +08:00
|
|
|
template<typename T>
|
2024-08-24 15:46:21 +08:00
|
|
|
static EXPRP clipConvert(EXPRP expr, bool supportRelu6) {
|
2023-12-04 11:12:20 +08:00
|
|
|
auto inputs = expr->inputs();
|
|
|
|
|
auto op = expr->get();
|
|
|
|
|
auto extraParam = op->main_as_Extra();
|
|
|
|
|
// auto dataType = expr->outputInfo(0)->type.code;
|
|
|
|
|
auto maxValue = std::numeric_limits<T>().max();
|
2025-08-22 18:04:08 +08:00
|
|
|
auto minValue = std::numeric_limits<T>().lowest();
|
2023-12-04 11:12:20 +08:00
|
|
|
if (nullptr != extraParam->attr()) {
|
|
|
|
|
const int attrSize = extraParam->attr()->size();
|
|
|
|
|
for (int i = 0; i < attrSize; ++i) {
|
|
|
|
|
auto attr = extraParam->attr()->GetAs<Attribute>(i);
|
|
|
|
|
const auto& key = attr->key()->str();
|
|
|
|
|
if (key == "max") {
|
|
|
|
|
maxValue = attr->f();
|
|
|
|
|
} else if (key == "min") {
|
|
|
|
|
minValue = attr->f();
|
2020-07-01 19:06:31 +08:00
|
|
|
}
|
|
|
|
|
}
|
2023-12-04 11:12:20 +08:00
|
|
|
}
|
|
|
|
|
bool unknown_min_max = false;
|
|
|
|
|
if (inputs.size() == 2 || (inputs.size() == 3 && inputs[1].get() != nullptr)) {
|
|
|
|
|
auto minPtr = inputs[1]->readMap<T>();
|
|
|
|
|
if (nullptr == minPtr) {
|
|
|
|
|
unknown_min_max = true;
|
|
|
|
|
} else {
|
|
|
|
|
minValue = minPtr[0];
|
2021-04-28 18:02:10 +08:00
|
|
|
}
|
2023-12-04 11:12:20 +08:00
|
|
|
}
|
|
|
|
|
if (inputs.size() == 3 && !unknown_min_max) {
|
|
|
|
|
auto maxPtr = inputs[2]->readMap<T>();
|
|
|
|
|
if (nullptr == maxPtr) {
|
|
|
|
|
unknown_min_max = true;
|
|
|
|
|
} else {
|
|
|
|
|
maxValue = maxPtr[0];
|
2022-01-04 10:50:40 +08:00
|
|
|
}
|
2023-12-04 11:12:20 +08:00
|
|
|
}
|
2024-08-24 15:46:21 +08:00
|
|
|
if (unknown_min_max || (!supportRelu6)) {
|
2023-12-04 11:12:20 +08:00
|
|
|
auto minVar = _Scalar<T>(minValue);
|
|
|
|
|
auto maxVar = _Scalar<T>(maxValue);
|
|
|
|
|
if (inputs.size() >= 2 && inputs[1].get() != nullptr) {
|
|
|
|
|
minVar = inputs[1];
|
|
|
|
|
}
|
|
|
|
|
if (inputs.size() >= 3) {
|
|
|
|
|
maxVar = inputs[2];
|
|
|
|
|
}
|
|
|
|
|
auto res = _Minimum(_Maximum(inputs[0], minVar), maxVar);
|
|
|
|
|
auto newExpr = res->expr().first;
|
|
|
|
|
newExpr->setName(expr->name());
|
|
|
|
|
return newExpr;
|
|
|
|
|
}
|
2024-02-29 16:21:40 +08:00
|
|
|
if(maxValue > std::numeric_limits<T>::max()) {
|
|
|
|
|
maxValue = std::numeric_limits<T>().max();
|
|
|
|
|
}
|
|
|
|
|
if(minValue < std::numeric_limits<T>::lowest()) {
|
|
|
|
|
minValue = std::numeric_limits<T>().lowest();
|
|
|
|
|
}
|
2023-12-04 11:12:20 +08:00
|
|
|
std::unique_ptr<OpT> newOp(new OpT);
|
|
|
|
|
newOp->type = OpType_ReLU6;
|
|
|
|
|
newOp->main.type = OpParameter_Relu6;
|
|
|
|
|
newOp->main.value = new Relu6T;
|
|
|
|
|
newOp->main.AsRelu6()->maxValue = maxValue;
|
|
|
|
|
newOp->main.AsRelu6()->minValue = minValue;
|
|
|
|
|
auto res = Expr::create(newOp.get(), {inputs[0]});
|
|
|
|
|
res->setName(expr->name());
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Transform for ONNX "Clip" nodes: picks the element type used to interpret the
// clip bounds and delegates the rewrite to clipConvert<T>.
class OnnxClipTransform : public OnnxExtraManager::Transform {
public:
    // Uses clipConvert<float> (ReLU6 allowed) when the first input with known info
    // is float, or when the node has a single input. Otherwise uses
    // clipConvert<int32_t> with ReLU6 disabled. If no input has known info, the
    // type code stays halide_type_int.
    virtual EXPRP onExecute(EXPRP expr) const override {
        auto inputs = expr->inputs();
        halide_type_code_t code = halide_type_int;
        for (const auto& input : inputs) {
            if (input.get() == nullptr) {
                continue;
            }
            auto info = input->getInfo();
            if (nullptr == info) {
                continue;
            }
            code = static_cast<halide_type_code_t>(info->type.code);
            break;
        }
        const bool useFloat = (code == halide_type_float) || (inputs.size() == 1);
        if (!useFloat) {
            return clipConvert<int32_t>(expr, false);
        }
        return clipConvert<float>(expr, true);
    }
};
|
|
|
|
|
|
|
|
|
|
// Registers OnnxClipTransform as the handler for ONNX "Clip" nodes when this
// translation unit's static objects are initialized.
static auto gRegister = []() {
    auto manager = OnnxExtraManager::get();
    manager->insert("Clip", std::shared_ptr<OnnxExtraManager::Transform>(new OnnxClipTransform));
    return true;
}();
|
|
|
|
|
|
|
|
|
|
} // namespace Express
|
|
|
|
|
} // namespace MNN
|