//
//  ShapeRNNSequenceGRU.cpp
//  MNN
//
//  Created by MNN on 2019/03/19.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "shape/SizeComputer.hpp"
#include "core/TensorUtils.hpp"

namespace MNN {
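
// Shape inference for the RNNSequenceGRU op: fills in the output tensors'
// shapes from the input sequence tensor and the op's RNNParam attributes.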
class RNNSequenceGRUComputer : public SizeComputer {
public:
    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                               const std::vector<Tensor*>& outputs) const override {
        MNN_ASSERT(1 == inputs.size());
        MNN_ASSERT(1 <= outputs.size());

        auto input  = inputs[0];
        auto output = outputs[0];
        MNN_ASSERT(3 == input->dimensions());

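        // GRU attributes; the forward gate weight must be 2-D with shape
        // [input->length(2) + numUnits, 2 * numUnits].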
        const auto rnnParam     = op->main_as_RNNParam();
        const int numUnits      = rnnParam->numUnits();
        bool keepAllOutputs     = rnnParam->keepAllOutputs();
        bool isBidirectionalRNN = rnnParam->isBidirectionalRNN();
        MNN_ASSERT(2 == rnnParam->fwGateWeight()->dims()->size());
        MNN_ASSERT(2 * numUnits == rnnParam->fwGateWeight()->dims()->data()[1]);
        output->buffer().type = halide_type_of<float>();
        TensorUtils::getDescribe(output)->dimensionFormat = TensorUtils::getDescribe(input)->dimensionFormat;
        MNN_ASSERT((input->length(2) + numUnits) == rnnParam->fwGateWeight()->dims()->data()[0]);
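        // keepAllOutputs: return the whole output sequence, i.e. the input's
        // shape with the last dimension replaced by numUnits.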
        if (keepAllOutputs) {
            TensorUtils::copyShape(input, output);
            output->setLength(2, rnnParam->numUnits());
            output->buffer().type = input->buffer().type;

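            // A bidirectional GRU delivers a second output for the backward
            // direction, with the same shape.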
            if (isBidirectionalRNN) {
                MNN_ASSERT(2 == outputs.size());
                auto outputBW = outputs[1];
                TensorUtils::copyShape(input, outputBW);
                outputBW->setLength(2, rnnParam->numUnits());
                outputBW->buffer().type = input->buffer().type;
            }
        } else {
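            // Without keepAllOutputs the output collapses to a single 2-D
            // tensor of shape [input dimension 0, numUnits] (the final state).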
            auto& inputBuffer  = input->buffer();
            auto& outputBuffer = output->buffer();
            outputBuffer.dimensions    = 2;
            outputBuffer.dim[0].extent = inputBuffer.dim[0].extent;
            outputBuffer.dim[1].extent = rnnParam->numUnits();
            outputBuffer.type          = inputBuffer.type;

            if (isBidirectionalRNN) {
                MNN_ASSERT(2 == outputs.size());
                auto outputBW = outputs[1];
                auto& outputBWBuffer = outputBW->buffer();
                outputBWBuffer.dimensions    = 2;
                outputBWBuffer.dim[0].extent = inputBuffer.dim[0].extent;
                outputBWBuffer.dim[1].extent = rnnParam->numUnits();
                outputBWBuffer.type          = inputBuffer.type;
            }
        }
        return true;
    }
};

REGISTER_SHAPE(RNNSequenceGRUComputer, OpType_RNNSequenceGRU);
} // namespace MNN