//
//  SelectGrad.cpp
//  MNN
//
//  Created by MNN on 2019/05/26.
//  Copyright © 2018, Alibaba Group Holding Limited
//
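
// Gradient (backward) rule for MNN's Select operator: select(cond, a, b) picks
// a where cond is true and b where it is false, so the upstream gradient is
// routed to a on the true positions and to b on the false ones; the condition
// itself receives no gradient (result[0] stays nullptr).
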
#include "SelectGrad.hpp"
#include "core/Macro.h"

using namespace std;
using namespace MNN;
using namespace MNN::Express;

class SelectGrad : public OpGrad {
public:
    SelectGrad() {
        // Select only masks the upstream gradient, so mark the grad as semi-linear.
        mType = SEMI_LINEAR;
    }
    virtual std::vector<Express::VARP> onGrad(Express::EXPRP expr, const std::vector<Express::VARP>& output,
                                              const std::vector<Express::VARP>& backwardOutput) override {
        // inputs: [0] = condition, [1] = value taken where the condition is true, [2] = value taken otherwise.
        auto inputs = expr->inputs();
        std::vector<VARP> result(inputs.size(), nullptr);
        auto outputDiff = backwardOutput[0];
        // d(select(x, a, b)) = da * mask + db * (1 - mask), where mask = float(x).
        {
            // Cast the boolean condition x into a float mask (1.0 where true, 0.0 where false).
            unique_ptr<OpT> mask(new OpT);
            mask->type                     = OpType_Cast;
            mask->main.type                = OpParameter_CastParam;
            mask->main.value               = new CastParamT;
            mask->main.AsCastParam()->dstT = DataType_DT_FLOAT;
            mask->main.AsCastParam()->srcT = DataType_DT_BOOL;

            auto maskVar = Variable::create(Expr::create(std::move(mask), {inputs[0]}));

            // da = outputDiff * mask: gradient flows to a where the condition is true.
            result[1] = _Multiply(outputDiff, maskVar);

            // db = outputDiff * (1 - mask), written as -(mask - 1): gradient flows to b elsewhere.
            auto one = _Const(1.0f);
            auto sub = _Subtract(maskVar, one);
            auto neg = _Negative(sub);
            result[2] = _Multiply(outputDiff, neg);
        }

        return result;
    }
};

static const auto gRegister = []() {
    // Register SelectGrad as the gradient implementation for the Select op.
    static SelectGrad _c;
    OpGrad::insert(OpType_Select, &_c);
    return true;
}();
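
// ---------------------------------------------------------------------------
// Illustration only (not part of the original file): a minimal, standalone
// sketch of the select-gradient rule implemented above, written with plain
// arrays instead of MNN's Express API. The names `cond`, `upstream`, `gradA`
// and `gradB` are hypothetical and chosen just for this example; it is kept
// inside `#if 0` so the translation unit is unaffected.
#if 0
#include <array>
#include <cstdio>

int main() {
    // Forward: y[i] = cond[i] ? a[i] : b[i].
    std::array<bool, 4>  cond     = {true, false, true, false};
    std::array<float, 4> upstream = {0.1f, 0.2f, 0.3f, 0.4f}; // dL/dy

    std::array<float, 4> gradA{}, gradB{};
    for (size_t i = 0; i < cond.size(); ++i) {
        float mask = cond[i] ? 1.0f : 0.0f;        // same role as the Cast-to-float mask above
        gradA[i]   = upstream[i] * mask;           // dL/da = dL/dy * mask
        gradB[i]   = upstream[i] * (1.0f - mask);  // dL/db = dL/dy * (1 - mask)
    }

    for (size_t i = 0; i < cond.size(); ++i) {
        std::printf("i=%zu dA=%.2f dB=%.2f\n", i, gradA[i], gradB[i]);
    }
    return 0;
}
#endif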