2019-12-27 22:16:57 +08:00
|
|
|
//
|
|
|
|
// CPUUnravelIndex.cpp
|
|
|
|
// MNN
|
|
|
|
//
|
|
|
|
// Created by MNN on 2019/11/26.
|
|
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
|
|
//
|
|
|
|
|
|
|
|
#include "backend/cpu/CPUUnravelIndex.hpp"
|
|
|
|
#include "backend/cpu/CPUBackend.hpp"
|
2020-11-05 16:41:56 +08:00
|
|
|
#include "core/OpCommonUtils.hpp"
|
2019-12-27 22:16:57 +08:00
|
|
|
|
|
|
|
namespace MNN {
|
|
|
|
|
|
|
|
ErrorCode CPUUnravelIndex::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
|
|
|
|
auto indices = inputs[0];
|
|
|
|
auto dims = inputs[1];
|
|
|
|
|
|
|
|
const int elmentSize = indices->elementSize();
|
|
|
|
const int dimsSize = dims->length(0);
|
|
|
|
|
|
|
|
const auto indicesPtr = indices->host<int32_t>();
|
|
|
|
const auto dimsDataPtr = dims->host<int32_t>();
|
2021-11-30 10:10:53 +08:00
|
|
|
int mod[MNN_MAX_TENSOR_DIM];
|
|
|
|
OpCommonUtils::computeStride(mod, dimsDataPtr, dimsSize);
|
2019-12-27 22:16:57 +08:00
|
|
|
auto outputDataPtr = outputs[0]->host<int32_t>();
|
|
|
|
|
2021-11-30 10:10:53 +08:00
|
|
|
int coordinate[MNN_MAX_TENSOR_DIM];
|
2019-12-27 22:16:57 +08:00
|
|
|
for (int i = 0; i < elmentSize; ++i) {
|
2020-11-05 16:41:56 +08:00
|
|
|
OpCommonUtils::unravelIndexHelper(coordinate, mod, dimsSize, indicesPtr[i]);
|
2019-12-27 22:16:57 +08:00
|
|
|
// assign value
|
|
|
|
for (int k = 0; k < dimsSize; ++k) {
|
|
|
|
outputDataPtr[i + k * elmentSize] = coordinate[k];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return NO_ERROR;
|
|
|
|
}
|
|
|
|
|
|
|
|
class CPUUnravelIndexCreator : public CPUBackend::Creator {
|
|
|
|
public:
|
|
|
|
virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
|
|
|
|
const MNN::Op* op, Backend* backend) const override {
|
|
|
|
return new CPUUnravelIndex(backend);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
REGISTER_CPU_OP_CREATOR(CPUUnravelIndexCreator, OpType_UnravelIndex);
|
|
|
|
|
|
|
|
} // namespace MNN
|