Fix bug for Gather's indice < 0

This commit is contained in:
xiaying 2021-12-07 12:54:12 +08:00
parent a2e1ed4c67
commit bd9ef418af
1 changed files with 12 additions and 1 deletions

View File

@ -8,6 +8,7 @@
#include "MNN_generated.h"
#include "OnnxExtraManager.hpp"
#include "config.hpp"
namespace MNN {
namespace Express {
@ -28,7 +29,17 @@ public:
}
}
}
auto output = _GatherV2(inputs[0], inputs[1], _Scalar<int>(axis));
auto axisVar = _Scalar<int>(axis);
auto config = Global<modelConfig>::Get();
if (config->optimizeLevel < 2) {
// Add negative protect, may decrease performance
auto rankVar = _Rank(inputs[0]);
axisVar = _Select(_GreaterEqual(axisVar, _Scalar<int>(0)), axisVar, axisVar + rankVar);
auto shapeVar = _Shape(inputs[0], true);
auto axisLengthVar = _Squeeze(_StridedSlice(shapeVar, _Unsqueeze(axisVar, {0}), _Unsqueeze(axisVar + _Scalar<int>(1), {0}), _Unsqueeze(_Scalar<int32_t>(1), {0}), 0, 0, 0, 0, 0));
inputs[1] = _Select(_GreaterEqual(inputs[1], _Scalar<int>(0)), inputs[1], inputs[1] + axisLengthVar);
}
auto output = _GatherV2(inputs[0], inputs[1], axisVar);
output->setName(expr->name());
return output->expr().first;
}