mirror of https://github.com/alibaba/MNN.git
117 lines
4.5 KiB
C++
117 lines
4.5 KiB
C++
|
//
|
||
|
// GeometryDepthToSpace.cpp
|
||
|
// MNN
|
||
|
//
|
||
|
// Created by MNN on 2020/04/23.
|
||
|
// Copyright © 2018, Alibaba Group Holding Limited
|
||
|
//
|
||
|
|
||
|
#include "geometry/GeometryComputer.hpp"
|
||
|
#include "core/Macro.h"
|
||
|
namespace MNN {
|
||
|
class GeometryDepthToSpace : public GeometryComputer {
|
||
|
public:
|
||
|
virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
|
||
|
Context& context, CommandBuffer& res) const override {
|
||
|
MNN_ASSERT(1 == outputs.size());
|
||
|
MNN_ASSERT(inputs.size() == 1);
|
||
|
const int blockSize = op->main_as_DepthSpaceParam()->blockSize();
|
||
|
auto mode = op->main_as_DepthSpaceParam()->mode();
|
||
|
auto input = inputs[0];
|
||
|
auto output = outputs[0];
|
||
|
auto outputDes = TensorUtils::getDescribe(output);
|
||
|
outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
|
||
|
auto realTensor = input;
|
||
|
// For OpType_SpaceToDepth, swap input and output
|
||
|
if (op->type() == OpType_SpaceToDepth) {
|
||
|
auto temp = output;
|
||
|
output = input;
|
||
|
input = temp;
|
||
|
}
|
||
|
|
||
|
const int inHeight = input->height();
|
||
|
const int inWidth = input->width();
|
||
|
const int inChannel = input->channel();
|
||
|
const int outHeight = output->height();
|
||
|
const int outWidth = output->width();
|
||
|
const int outChannel = output->channel();
|
||
|
// NCHW Stride
|
||
|
int inputStride[4];
|
||
|
int outputStride[4];
|
||
|
if (MNN_DATA_FORMAT_NHWC == outputDes->dimensionFormat) {
|
||
|
inputStride[0] = inWidth * inHeight * inChannel;
|
||
|
inputStride[1] = 1;
|
||
|
inputStride[2] = inWidth * inChannel;
|
||
|
inputStride[3] = inChannel;
|
||
|
|
||
|
outputStride[0] = outWidth * outHeight * outChannel;
|
||
|
outputStride[1] = 1;
|
||
|
outputStride[2] = outWidth * outChannel;
|
||
|
outputStride[3] = outChannel;
|
||
|
} else {
|
||
|
inputStride[0] = inWidth * inHeight * inChannel;
|
||
|
inputStride[1] = inWidth * inHeight;
|
||
|
inputStride[2] = inWidth;
|
||
|
inputStride[3] = 1;
|
||
|
|
||
|
outputStride[0] = outWidth * outHeight * outChannel;
|
||
|
outputStride[1] = outHeight * outWidth;
|
||
|
outputStride[2] = outWidth;
|
||
|
outputStride[3] = 1;
|
||
|
}
|
||
|
auto batch = input->batch();
|
||
|
auto regionSize = blockSize * blockSize * batch;
|
||
|
outputDes->regions.resize(regionSize);
|
||
|
for (int b = 0; b < batch; ++b) {
|
||
|
auto dstB = b * outputStride[0];
|
||
|
auto srcB = b * inputStride[0];
|
||
|
for (int hb = 0; hb < blockSize; ++hb) {
|
||
|
auto dstHB = dstB + hb * outputStride[2];
|
||
|
for (int wb = 0; wb < blockSize; ++wb) {
|
||
|
auto dstWB = dstHB + wb * outputStride[3];
|
||
|
int offsetC = hb * blockSize + wb;
|
||
|
if (mode == DepthToSpaceMode_DCR) {
|
||
|
offsetC *= outChannel;
|
||
|
}
|
||
|
auto srcWB = srcB + offsetC * inputStride[1];
|
||
|
|
||
|
auto& region = outputDes->regions[b * blockSize * blockSize + wb + hb * blockSize];
|
||
|
region.origin = realTensor;
|
||
|
region.size[0] = inHeight;
|
||
|
region.size[1] = inWidth;
|
||
|
region.size[2] = outChannel;
|
||
|
|
||
|
auto srcR = ®ion.src;
|
||
|
auto dstR = ®ion.dst;
|
||
|
if (op->type() == OpType_SpaceToDepth) {
|
||
|
srcR = ®ion.dst;
|
||
|
dstR = ®ion.src;
|
||
|
}
|
||
|
|
||
|
dstR->offset = dstWB;
|
||
|
dstR->stride[0] = outputStride[2] * blockSize;
|
||
|
dstR->stride[1] = outputStride[3] * blockSize;
|
||
|
dstR->stride[2] = outputStride[1];
|
||
|
|
||
|
srcR->offset = srcWB;
|
||
|
srcR->stride[0] = inputStride[2];
|
||
|
srcR->stride[1] = inputStride[3];
|
||
|
srcR->stride[2] = inputStride[1];
|
||
|
if (mode == DepthToSpaceMode_CRD) {
|
||
|
srcR->stride[2] *= (blockSize * blockSize);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
};
|
||
|
static void _create() {
|
||
|
std::shared_ptr<GeometryComputer> comp(new GeometryDepthToSpace);
|
||
|
GeometryComputer::registerGeometryComputer(comp, {OpType_DepthToSpace, OpType_SpaceToDepth});
|
||
|
}
|
||
|
|
||
|
REGISTER_GEOMETRY(GeometryDepthToSpace, _create);
|
||
|
|
||
|
} // namespace MNN
|