diff --git a/tools/converter/source/optimizer/onnxextra/OnnxGather.cpp b/tools/converter/source/optimizer/onnxextra/OnnxGather.cpp index fd5336452..273b29782 100644 --- a/tools/converter/source/optimizer/onnxextra/OnnxGather.cpp +++ b/tools/converter/source/optimizer/onnxextra/OnnxGather.cpp @@ -8,6 +8,7 @@ #include "MNN_generated.h" #include "OnnxExtraManager.hpp" +#include "config.hpp" namespace MNN { namespace Express { @@ -28,7 +29,17 @@ public: } } } - auto output = _GatherV2(inputs[0], inputs[1], _Scalar(axis)); + auto axisVar = _Scalar(axis); + auto config = Global::Get(); + if (config->optimizeLevel < 2) { + // Add negative protect, may decrease performance + auto rankVar = _Rank(inputs[0]); + axisVar = _Select(_GreaterEqual(axisVar, _Scalar(0)), axisVar, axisVar + rankVar); + auto shapeVar = _Shape(inputs[0], true); + auto axisLengthVar = _Squeeze(_StridedSlice(shapeVar, _Unsqueeze(axisVar, {0}), _Unsqueeze(axisVar + _Scalar(1), {0}), _Unsqueeze(_Scalar(1), {0}), 0, 0, 0, 0, 0)); + inputs[1] = _Select(_GreaterEqual(inputs[1], _Scalar(0)), inputs[1], inputs[1] + axisLengthVar); + } + auto output = _GatherV2(inputs[0], inputs[1], axisVar); output->setName(expr->name()); return output->expr().first; }