// Mirror of https://github.com/alibaba/MNN.git (52 lines, 1.6 KiB, C++)
//
//  NPUSlice.cpp
//  MNN
//
//  Created by MNN on 2020/10/15.
//  Copyright © 2018, Alibaba Group Holding Limited
//
|
|
|
|
#include "NPUSlice.hpp"
|
|
#include "NPUBackend.hpp"
|
|
|
|
using namespace std;
|
|
|
|
namespace MNN {
|
|
|
|
// Constructs the Slice execution by forwarding the backend and op to
// NPUCommonExecution. The inputs/outputs vectors are accepted but not used here.
NPUSlice::NPUSlice(Backend *b, const Op *op, const std::vector<Tensor *> &inputs,
                   const std::vector<Tensor *> &outputs)
    : NPUCommonExecution(b, op) {
}
|
|
|
|
// Creates a hiai::op::SplitD node for this Slice op and hands it to the NPU backend.
//
// inputs  - input tensors; passed to NPUBackend::setNetworkInput together with mOp.
// outputs - output tensors; their count is the default number of splits, and they
//           are associated with the created node through NPUBackend::setOutputOps.
//
// Always returns NO_ERROR.
ErrorCode NPUSlice::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    mNpuBackend->setNetworkInput(inputs, mOp);

    const auto nodeName = mOp->name()->str();
    std::shared_ptr<hiai::op::SplitD> splitOp(new hiai::op::SplitD(nodeName));

    const auto *sliceParam = mOp->main_as_Slice();
    const auto splitAxis   = sliceParam->axis();
    const auto *points     = sliceParam->slicePoints();

    // Default: one split per output tensor. This also covers the cases where no
    // slice points are stored, or fewer slice points than outputs are stored.
    int64_t numSplit = static_cast<int64_t>(outputs.size());
    if (points != nullptr && points->size() >= outputs.size()) {
        if (points->size() == 1) {
            // A single stored entry is used directly as the split count.
            numSplit = static_cast<int64_t>(points->Get(0));
        } else {
            // Otherwise the number of stored entries is the split count.
            numSplit = static_cast<int64_t>(points->size());
        }
    }

    auto inputOp = mNpuBackend->getInputOps(mOp);

    splitOp->set_input_x(*inputOp.get());
    splitOp->set_attr_split_dim(splitAxis);
    splitOp->set_attr_num_split(numSplit);
    splitOp->create_dynamic_output_y(numSplit);

    mNpuBackend->setOutputOps(mOp, {splitOp}, outputs);
    return NO_ERROR;
}
|
|
|
|
// Namespace-scope registration object constructed with OpType_Slice.
// NOTE(review): presumably its constructor registers TypedCreator<NPUSlice> with the
// NPU backend so Slice ops are dispatched to NPUSlice — confirm in NPUBackend.hpp.
NPUCreatorRegister<TypedCreator<NPUSlice>> __slice_op(OpType_Slice);
|
|
|
|
} // namespace MNN
|