//
//  ShapeSpaceToBatchND.cpp
//  MNN
//
//  Created by MNN on 2019/01/10.
//  Copyright © 2018, Alibaba Group Holding Limited
//
2019-12-27 22:16:57 +08:00
|
|
|
#include "core/SizeComputer.hpp"
|
|
|
|
#include "core/TensorUtils.hpp"
|
2019-04-17 10:49:11 +08:00
|
|
|
|
|
|
|
namespace MNN {
|
|
|
|
class SpaceToBatchNDSizeComputer : public SizeComputer {
|
|
|
|
public:
|
|
|
|
virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
|
|
|
|
const std::vector<Tensor*>& outputs) const override {
|
|
|
|
auto input = inputs[0];
|
|
|
|
auto output = outputs[0];
|
|
|
|
|
|
|
|
auto paramter = op->main_as_SpaceBatch();
|
|
|
|
const auto blockShape = paramter->blockShape();
|
|
|
|
int batch = input->batch();
|
|
|
|
for (int i = 0; i < blockShape->dims()->data()[0]; ++i) {
|
|
|
|
batch *= blockShape->int32s()->data()[i];
|
|
|
|
}
|
|
|
|
|
|
|
|
const auto paddings = paramter->padding();
|
|
|
|
const auto paddingData = paddings->int32s()->data();
|
|
|
|
int paddedHeight = input->height() + paddingData[0] + paddingData[1];
|
|
|
|
int paddedWidth = input->width() + paddingData[2] + paddingData[3];
|
|
|
|
int outputHeight = paddedHeight / blockShape->int32s()->data()[0];
|
|
|
|
int outputWidth = paddedWidth / blockShape->int32s()->data()[1];
|
2020-02-26 09:57:17 +08:00
|
|
|
output->buffer().type = input->buffer().type;
|
2019-04-17 10:49:11 +08:00
|
|
|
output->buffer().dimensions = input->buffer().dimensions;
|
|
|
|
output->setLength(0, batch);
|
|
|
|
output->setLength(1, input->channel());
|
|
|
|
output->setLength(2, outputHeight);
|
|
|
|
output->setLength(3, outputWidth);
|
2019-08-22 20:13:46 +08:00
|
|
|
TensorUtils::getDescribe(output)->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;
|
2019-04-17 10:49:11 +08:00
|
|
|
return true;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
REGISTER_SHAPE(SpaceToBatchNDSizeComputer, OpType_SpaceToBatchND);
|
|
|
|
} // namespace MNN
|