Merge pull request #3765 from alibaba/feautre/bugfix
android / android_build (push) Has been cancelled Details
ios / ios_build (push) Has been cancelled Details
linux / linux_buil_test (push) Has been cancelled Details
macos / macos_buil_test (push) Has been cancelled Details
windows / windows_build_test (push) Has been cancelled Details

MNN:Bugfix: Fix bug for GeometryReverse don't clear origin region
This commit is contained in:
jxt1234 2025-07-30 12:47:47 +08:00 committed by GitHub
commit ba68171bbe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 32 additions and 7 deletions

View File

@ -576,6 +576,9 @@ static void _blit(const Tensor::InsideDescribe::Region& slice, int bytes, const
for (int y=0; y<slice.size[1]; ++y) {
auto srcY = srcZ + y * slice.src.stride[1] * bytes;
auto dstY = dstZ + y * slice.dst.stride[1] * bytes;
#ifdef DEBUG
::memset(dstY, 0, slice.size[2] * bytes);
#endif
::memcpy(dstY, srcY, slice.size[2] * bytes);
}
}

View File

@ -176,7 +176,7 @@ public:
region.dst.stride[0] = reverseSize*insideSize;
region.dst.stride[1] = insideSize;
region.dst.stride[2] = 1;
outputDes->regions.emplace_back(std::move(region));
outputDes->regions = {std::move(region)};
return true;
}

View File

@ -8,6 +8,7 @@
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include <MNN/expr/Module.hpp>
#include "MNNTestSuite.h"
#include "TestUtils.h"
@ -16,6 +17,19 @@ class ReverseTest : public MNNTestCase {
public:
virtual ~ReverseTest() = default;
virtual bool run(int precision) {
std::shared_ptr<MNN::Express::Module> net;
{
auto input = _Input({3, 2, 3}, NCHW);
input->setName("i");
auto output0 = _Reverse(input, _Scalar<int32_t>(0));
auto output1 = _Reverse(input, _Scalar<int32_t>(1));
auto output2 = _Reverse(input, _Scalar<int32_t>(2));
output0->setName("o0");
output1->setName("o1");
output2->setName("o2");
auto buffer = Variable::save({output0, output1, output2});
net.reset(MNN::Express::Module::load({"i"}, {"o0", "o1", "o2"}, (uint8_t*)buffer.data(), buffer.size()));
}
{
auto input = _Input({3, 2, 3}, NCHW);
input->setName("input_tensor");
@ -25,7 +39,9 @@ public:
13, 14, 15, 16, 17, 18 };
auto inputPtr = input->writeMap<float>();
memcpy(inputPtr, inpudata, 18 * sizeof(float));
auto output0 = _Reverse(input, _Scalar<int32_t>(0));
auto outputs = net->onForward({input});
auto output0 = outputs[0];
const std::vector<float> expectedOutput0 = { 13, 14, 15, 16, 17, 18,
7, 8, 9, 10, 11, 12,
1, 2, 3, 4, 5, 6 };
@ -37,7 +53,7 @@ public:
return false;
}
}
auto output1 = _Reverse(input, _Scalar<int32_t>(1));
auto output1 = outputs[1];
const std::vector<float> expectedOutput1 = { 4, 5, 6, 1, 2, 3,
10, 11, 12, 7, 8, 9,
16, 17, 18, 13, 14, 15 };
@ -49,7 +65,7 @@ public:
return false;
}
}
auto output2 = _Reverse(input, _Scalar<int32_t>(2));
auto output2 = outputs[2];
const std::vector<float> expectedOutput2 = { 3, 2, 1, 6, 5, 4,
9, 8, 7, 12, 11, 10,
15, 14, 13, 18, 17, 16 };
@ -73,7 +89,8 @@ public:
13, 14, 15, 16 };
auto inputPtr = input->writeMap<uint8_t>();
memcpy(inputPtr, inpudata, 16 * sizeof(uint8_t));
auto output0 = _Reverse(input, _Scalar<int32_t>(0));
auto outputs = net->onForward({input});
auto output0 = outputs[0];
const std::vector<uint8_t> expectedOutput0 = {
9, 10, 11, 12, 13, 14, 15, 16,
1, 2, 3, 4, 5, 6, 7, 8 };
@ -85,7 +102,7 @@ public:
return false;
}
}
auto output1 = _Reverse(input, _Scalar<int32_t>(1));
auto output1 = outputs[1];
const std::vector<uint8_t> expectedOutput1 = {5, 6, 7, 8, 1, 2, 3, 4,
13, 14, 15, 16, 9, 10, 11, 12,
};
@ -97,7 +114,7 @@ public:
return false;
}
}
auto output2 = _Reverse(input, _Scalar<int32_t>(2));
auto output2 = outputs[2];
const std::vector<uint8_t> expectedOutput2 = { 3, 4, 1, 2, 7, 8, 5, 6,
11, 12, 9, 10, 15, 16, 13, 14 };
auto gotOutput2 = output2->readMap<uint8_t>();
@ -119,6 +136,11 @@ public:
return false;
}
}
input->resize({1, 1, 1, 1});
input->writeMap<uint8_t>()[0] = 74;
auto c0 = output1->readMap<uint8_t>()[0];
auto c1 = output2->readMap<uint8_t>()[0];
auto c2 = output3->readMap<uint8_t>()[0];
}
{ // test SizeComputer::needInputContent