MNN/tools/train/source/grad/ReluGrad.cpp

//
// ReluGrad.cpp
// MNN
//
// Created by MNN on 2019/04/22.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "ReluGrad.hpp"
#include "core/Macro.h"
#include <string.h>
using namespace std;
using namespace MNN;
using namespace MNN::Express;
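// Gradient of PReLU: dy/dx is 1 where x > 0 and slope where x < 0. With
// mask = (sign(x) + 1) / 2 (1 on the positive half, 0 on the negative half,
// 0.5 at exactly x == 0), this gives dL/dx = (mask + (1 - mask) * slope) * dL/dy.
// Example: slope = 0.25, x = -2, dL/dy = 1 -> mask = 0, dL/dx = 0.25.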
class PReluGrad : public OpGrad {
public:
    virtual std::vector<Express::VARP> onGrad(Express::EXPRP expr,
                                              const std::vector<Express::VARP>& backwardOutput) override {
        std::vector<Express::VARP> result(1, nullptr);
        auto op    = expr->get();
        auto input = expr->inputs()[0];
        auto mask  = (_Sign(input) + _Scalar<float>(1.0f)) * _Scalar<float>(0.5f);
        auto prelu = op->main_as_PRelu();
        if (prelu->slope()->size() == 1) {
            // A single shared slope can be folded directly into the mask expression
            auto slope = prelu->slope()->data()[0];
            result[0]  = (mask + (_Scalar<float>(1.0f) - mask) * _Scalar<float>(slope)) * backwardOutput[0];
            return result;
        }
        // Per-channel slopes: scale the incoming gradient channel-wise with _Scale
        // (zero bias), then blend the two halves by the mask
        auto channel = prelu->slope()->size();
        std::vector<float> scale(channel);
        ::memcpy(scale.data(), prelu->slope()->data(), channel * sizeof(float));
        std::vector<float> bias(channel, 0.0f);
        auto outputSecond = _Scale(backwardOutput[0], channel, std::move(scale), std::move(bias));
        result[0] = mask * backwardOutput[0] + (_Scalar<float>(1.0f) - mask) * outputSecond;
        // auto diffInfo  = result[0]->getInfo();
        // auto inputInfo = input->getInfo();
        // for (int i = 0; i < diffInfo->dim.size(); ++i) {
        //     MNN_ASSERT(diffInfo->dim[i] == inputInfo->dim[i]);
        //     MNN_PRINT("%s, %d, %d - %d\n", expr->name().c_str(), i, diffInfo->dim[i], inputInfo->dim[i]);
        // }
        // MNN_ASSERT(diffInfo->order == inputInfo->order);
        return result;
    }
};
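// Gradient of ReLU, reusing the same positive-half mask: the incoming gradient
// passes through unchanged where x > 0 and is zeroed (or, for leaky ReLU,
// scaled by the op's slope) where x < 0.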
class ReluGrad : public OpGrad {
public:
    ReluGrad() {
        mType = SEMI_LINEAR;
    }
    virtual std::vector<Express::VARP> onGrad(Express::EXPRP expr,
                                              const std::vector<Express::VARP>& backwardOutput) override {
        std::vector<Express::VARP> result(1, nullptr);
        auto op    = expr->get();
        auto input = expr->inputs()[0];
        auto mask  = (_Sign(input) + _Scalar<float>(1.0f)) * _Scalar<float>(0.5f);
        if (nullptr != op->main_as_Relu() && op->main_as_Relu()->slope() != 0.0f) {
            // Leaky ReLU: scale the gradient by the slope on the negative half
            result[0] = (mask + (_Scalar<float>(1.0f) - mask) * _Scalar<float>(op->main_as_Relu()->slope())) * backwardOutput[0];
            return result;
        }
        // Plain ReLU: the gradient only flows where the input was positive
        result[0] = mask * backwardOutput[0];
        return result;
    }
};
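// Gradient of ReLU6 = min(max(x, 0), 6): dy/dx is 1 on the interval (0, 6)
// and 0 outside, built here as the product of two sign-based masks.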
class Relu6Grad : public OpGrad {
public:
    Relu6Grad() {
        mType = SEMI_LINEAR;
    }
    virtual std::vector<Express::VARP> onGrad(Express::EXPRP expr,
                                              const std::vector<Express::VARP>& backwardOutput) override {
        std::vector<Express::VARP> result{nullptr};
        auto input = expr->inputs()[0];
        // mask0 is 2 where input > 0, mask1 is 2 where input < 6; their product
        // times 0.25 is 1 exactly where relu6 has unit slope
        auto mask0 = (_Sign(input) + _Scalar<float>(1.0f));
        auto mask1 = (_Sign(_Scalar<float>(6.0f) - input) + _Scalar<float>(1.0f));
        result[0]  = mask0 * mask1 * backwardOutput[0] * _Scalar<float>(0.25f);
        return result;
    }
};
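// Register one static instance of each gradient with OpGrad's type table; the
// immediately invoked lambda runs once at static-initialization time.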
static const auto gRegister = []() {
    static ReluGrad _c;
    OpGrad::insert(OpType_ReLU, &_c);
    static Relu6Grad _d;
    OpGrad::insert(OpType_ReLU6, &_d);
    static PReluGrad _e;
    OpGrad::insert(OpType_PReLU, &_e);
    return true;
}();