MNN/express/module/FixModule.cpp

53 lines
1.4 KiB
C++
Raw Normal View History

2019-12-27 22:16:57 +08:00
//
// FixModule.cpp
// MNN
//
// Created by MNN on 2019/12/16.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "FixModule.hpp"
#include <MNN/expr/ExprCreator.hpp>
using namespace MNN::Express;
namespace MNN {
2020-11-05 16:41:56 +08:00
namespace Express {
2019-12-27 22:16:57 +08:00
FixModule::FixModule(std::vector<Express::VARP> output, std::vector<Express::VARP> parameters,
                     std::vector<std::pair<Express::VARP, Express::Dimensionformat>> inputs) {
    // Register every parameter through the Module base-class API. Bind by
    // reference: the original per-iteration copy only duplicated the VARP
    // handle before passing it on.
    for (auto& p : parameters) {
        addParameter(p);
    }
    // All three vectors are received by value, so the ones we keep are moved
    // into the members instead of copied again.
    mInputs = std::move(inputs);
    mOutput = std::move(output);
}
2020-02-26 09:57:17 +08:00
void FixModule::onClearCache() {
for (auto v : mInputs) {
v.first.fix(VARP::INPUT);
}
}
2019-12-27 22:16:57 +08:00
std::vector<Express::VARP> FixModule::onForward(const std::vector<Express::VARP>& inputs) {
    // Bind the caller's inputs to the stored input variables and return the
    // fixed output list. inputs[i] is paired with mInputs[i] by position.
    MNN_ASSERT(inputs.size() == mInputs.size());
    // MNN_ASSERT may be compiled out in non-debug builds, so clamp the loop to
    // the shorter of the two vectors; otherwise a caller passing too many
    // inputs would index past the end of mInputs.
    const size_t count = inputs.size() < mInputs.size() ? inputs.size() : mInputs.size();
    for (size_t i = 0; i < count; ++i) {
        // Convert to the dimension format recorded for this input slot, then
        // swap the converted variable into the graph in place of the stored one.
        auto converted = _Convert(inputs[i], mInputs[i].second);
        Variable::replace(mInputs[i].first, converted);
    }
    return mOutput;
}
2020-11-05 16:41:56 +08:00
Module* FixModule::clone(CloneContext* ctx) const {
    // Create a blank FixModule and fill it with variables obtained from the
    // clone context; the base-class helper finishes the copy and returns it.
    FixModule* copied = new FixModule;

    // Input slots: clone the variable, keep the recorded dimension format.
    for (const auto& slot : mInputs) {
        copied->mInputs.emplace_back(ctx->getOrClone(slot.first), slot.second);
    }

    // Output variables: clone each through the same context.
    for (const auto& outVar : mOutput) {
        copied->mOutput.push_back(ctx->getOrClone(outVar));
    }

    return this->cloneBaseTo(ctx, copied);
}
} // namespace Express
2019-12-27 22:16:57 +08:00
} // namespace MNN