mirror of https://github.com/alibaba/MNN.git
fix pool shape and onnx slice, add printShape() method
This commit is contained in:
parent
5a4aaffb2a
commit
148353c777
|
|
@ -259,6 +259,11 @@ public:
|
|||
* @brief print tensor data. for DEBUG use only.
|
||||
*/
|
||||
void print() const;
|
||||
|
||||
/**
|
||||
*@brief print tensor shape
|
||||
*/
|
||||
void printShape() const;
|
||||
|
||||
private:
|
||||
halide_buffer_t mBuffer;
|
||||
|
|
|
|||
|
|
@ -416,4 +416,16 @@ void Tensor::print() const {
|
|||
}
|
||||
}
|
||||
|
||||
void Tensor::printShape() const {
|
||||
const int dims = this->dimensions();
|
||||
MNN_PRINT("\t**Tensor shape**: ");
|
||||
if (dims == 0) {
|
||||
MNN_PRINT("\t*Scalar*");
|
||||
}
|
||||
for (int i = 0; i < dims; ++i) {
|
||||
MNN_PRINT("%d, ", this->length(i));
|
||||
}
|
||||
MNN_PRINT("\n");
|
||||
}
|
||||
|
||||
} // namespace MNN
|
||||
|
|
|
|||
|
|
@ -57,7 +57,9 @@ class ConcatSizeComputer : public SizeComputer {
|
|||
}
|
||||
if (t->length(i) != outputs[0]->length(i)) {
|
||||
auto name = op->name() ? op->name()->c_str() : "";
|
||||
MNN_PRINT("Error for concat size of op %s, %d input not match output\n", name, i);
|
||||
MNN_PRINT("Error for concat size of op [ %s ], the %d input not match output\n", name, i);
|
||||
t->printShape();
|
||||
outputs[0]->printShape();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,8 +43,8 @@ public:
|
|||
w += layer->padX() * 2;
|
||||
h += layer->padY() * 2;
|
||||
}
|
||||
int kernelWidth = std::min(layer->kernelX(), input->width());
|
||||
int kernelHeight = std::min(layer->kernelY(), input->height());
|
||||
int kernelWidth = std::min(layer->kernelX(), w);
|
||||
int kernelHeight = std::min(layer->kernelY(), h);
|
||||
|
||||
if (layer->padType() == PoolPadType_SAME) { // Tensorflow padding mode SAME
|
||||
outw = ceil((float)w / (float)layer->strideX());
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ public:
|
|||
auto inputs = expr->inputs();
|
||||
auto op = expr->get();
|
||||
std::unique_ptr<OpT> poolOp(new OpT);
|
||||
poolOp->name = op->name()->c_str();
|
||||
auto extraParam = op->main_as_Extra();
|
||||
bool is3DPooling = false;
|
||||
int attrSize = 0;
|
||||
|
|
|
|||
|
|
@ -132,6 +132,11 @@ public:
|
|||
auto EndVar = MakeConstVecVar(tfEnd);
|
||||
auto StridesVar = MakeConstVecVar(tfStrides);
|
||||
sliceOp->type = OpType_StridedSlice;
|
||||
sliceOp->main.type = OpParameter_StridedSliceParam;
|
||||
auto param = new StridedSliceParamT;
|
||||
param->Index = DataType_DT_INT32;
|
||||
param->T = DataType_DT_FLOAT;
|
||||
sliceOp->main.value = param;
|
||||
return Expr::create(sliceOp.get(), {input, beginVar, EndVar, StridesVar}, expr->outputSize());
|
||||
} else {
|
||||
std::vector<int> tfBegin(ndim, 0);
|
||||
|
|
|
|||
Loading…
Reference in New Issue