mirror of https://github.com/alibaba/MNN.git
Fix bug for Gather's indice < 0
This commit is contained in:
parent
a2e1ed4c67
commit
bd9ef418af
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "MNN_generated.h"
|
||||
#include "OnnxExtraManager.hpp"
|
||||
#include "config.hpp"
|
||||
|
||||
namespace MNN {
|
||||
namespace Express {
|
||||
|
|
@ -28,7 +29,17 @@ public:
|
|||
}
|
||||
}
|
||||
}
|
||||
auto output = _GatherV2(inputs[0], inputs[1], _Scalar<int>(axis));
|
||||
auto axisVar = _Scalar<int>(axis);
|
||||
auto config = Global<modelConfig>::Get();
|
||||
if (config->optimizeLevel < 2) {
|
||||
// Add negative protect, may decrease performance
|
||||
auto rankVar = _Rank(inputs[0]);
|
||||
axisVar = _Select(_GreaterEqual(axisVar, _Scalar<int>(0)), axisVar, axisVar + rankVar);
|
||||
auto shapeVar = _Shape(inputs[0], true);
|
||||
auto axisLengthVar = _Squeeze(_StridedSlice(shapeVar, _Unsqueeze(axisVar, {0}), _Unsqueeze(axisVar + _Scalar<int>(1), {0}), _Unsqueeze(_Scalar<int32_t>(1), {0}), 0, 0, 0, 0, 0));
|
||||
inputs[1] = _Select(_GreaterEqual(inputs[1], _Scalar<int>(0)), inputs[1], inputs[1] + axisLengthVar);
|
||||
}
|
||||
auto output = _GatherV2(inputs[0], inputs[1], axisVar);
|
||||
output->setName(expr->name());
|
||||
return output->expr().first;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue