mirror of https://github.com/alibaba/MNN.git
improvement(arm82): optimize the arm82 raster op by using multiple threads
This commit is contained in:
parent
b0dbe49776
commit
caeab08754
|
|
@ -183,7 +183,6 @@ ErrorCode Arm82Raster::onExecute(const std::vector<Tensor *> &inputs, const std:
|
|||
bytes = 2;
|
||||
}
|
||||
auto threadNum = static_cast<Arm82Backend*>(backend())->numberThread();
|
||||
|
||||
if (mNeedZero) {
|
||||
auto size = bytes;
|
||||
const int dimensions = input->dimensions();
|
||||
|
|
@ -217,48 +216,67 @@ ErrorCode Arm82Raster::onExecute(const std::vector<Tensor *> &inputs, const std:
|
|||
break;
|
||||
}
|
||||
auto byteC4 = bytes * 8;
|
||||
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
|
||||
for (int u=tId; u<mFastBlit.size(); u+=threadNum) {
|
||||
auto& iter = mFastBlit[u];
|
||||
auto& slice = iter.second;
|
||||
//Offset use byte
|
||||
auto srcPtr = (uint8_t*)iter.first + slice.src.offset * bytes;
|
||||
auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes;
|
||||
if (slice.src.stride[1] == slice.size[2] && slice.dst.stride[1] == slice.size[2] && slice.src.stride[2] == 1) {
|
||||
for (int z=0; z<slice.size[0]; ++z) {
|
||||
auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4;
|
||||
auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4;
|
||||
::memcpy(dstZ, srcZ, slice.size[1] * slice.src.stride[1] * byteC4);
|
||||
|
||||
for (int i = 0; i < mFastBlit.size(); i++) {
|
||||
auto& iter = mFastBlit[i];
|
||||
auto& slice = iter.second;
|
||||
|
||||
//Offset use byte
|
||||
auto srcPtr = (uint8_t*)iter.first + slice.src.offset * bytes;
|
||||
auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes;
|
||||
if (slice.src.stride[1] == slice.size[2] && slice.dst.stride[1] == slice.size[2] && slice.src.stride[2] == 1) {
|
||||
int subPatch = (slice.size[1] * slice.src.stride[1] * byteC4) / threadNum;
|
||||
int extraPatch = slice.size[1] * slice.src.stride[1] * byteC4 - subPatch * threadNum;
|
||||
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
|
||||
for (int z = 0; z < slice.size[0]; ++z) {
|
||||
auto srcZ = srcPtr + subPatch * tId + z * slice.src.stride[0] * byteC4;
|
||||
auto dstZ = dstPtr + subPatch * tId + z * slice.dst.stride[0] * byteC4;
|
||||
::memcpy(dstZ, srcZ, subPatch);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (1 == slice.src.stride[2] && 1 == slice.dst.stride[2]) {
|
||||
for (int z=0; z<slice.size[0]; ++z) {
|
||||
auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4;
|
||||
auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4;
|
||||
for (int y=0; y<slice.size[1]; ++y) {
|
||||
auto srcY = srcZ + y * slice.src.stride[1] * byteC4;
|
||||
auto dstY = dstZ + y * slice.dst.stride[1] * byteC4;
|
||||
::memcpy(dstY, srcY, slice.size[2] * byteC4);
|
||||
}
|
||||
MNN_CONCURRENCY_END();
|
||||
if (extraPatch > 0) {
|
||||
for (int z = 0; z < slice.size[0]; ++z) {
|
||||
auto srcZ = srcPtr + subPatch * threadNum + z * slice.src.stride[0] * byteC4;
|
||||
auto dstZ = dstPtr + subPatch * threadNum + z * slice.dst.stride[0] * byteC4;
|
||||
::memcpy(dstZ, srcZ, extraPatch);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (1 == slice.src.stride[2] && 1 == slice.dst.stride[2]) {
|
||||
for (int z=0; z<slice.size[0]; ++z) {
|
||||
auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4;
|
||||
auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4;
|
||||
for (int y=0; y<slice.size[1]; ++y) {
|
||||
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
|
||||
for (int y = tId; y < slice.size[1]; y += threadNum) {
|
||||
auto srcY = srcZ + y * slice.src.stride[1] * byteC4;
|
||||
auto dstY = dstZ + y * slice.dst.stride[1] * byteC4;
|
||||
::memcpy(dstY, srcY, slice.size[2] * byteC4);
|
||||
}
|
||||
}
|
||||
MNN_CONCURRENCY_END();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int z = 0; z < slice.size[0]; ++z) {
|
||||
auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4;
|
||||
auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4;
|
||||
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
|
||||
for (int y = tId; y < slice.size[1]; y += threadNum) {
|
||||
auto srcY = srcZ + y * slice.src.stride[1] * byteC4;
|
||||
auto dstY = dstZ + y * slice.dst.stride[1] * byteC4;
|
||||
C4proc(dstY, srcY, slice.size[2], slice.src.stride[2], slice.dst.stride[2]);
|
||||
}
|
||||
}
|
||||
MNN_CONCURRENCY_END();
|
||||
}
|
||||
}
|
||||
MNN_CONCURRENCY_END();
|
||||
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
for (auto& iter : mTempInput) {
|
||||
backend()->onCopyBuffer(iter.first, iter.second.get());
|
||||
}
|
||||
|
|
@ -277,44 +295,61 @@ ErrorCode Arm82Raster::onExecute(const std::vector<Tensor *> &inputs, const std:
|
|||
MNN_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
|
||||
for (int u=(int)tId; u<mTempInputCopy.size(); u+=threadNum) {
|
||||
auto& iter = mTempInputCopy[u];
|
||||
auto& slice = *(iter.second);
|
||||
auto srcPtr = (uint8_t*)iter.first + slice.src.offset * bytes;
|
||||
auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes;
|
||||
if (slice.src.stride[1] == slice.size[2] && slice.dst.stride[1] == slice.size[2] && slice.src.stride[2] == 1) {
|
||||
for (int z=0; z<slice.size[0]; ++z) {
|
||||
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes;
|
||||
auto dstZ = dstPtr + z * slice.dst.stride[0] * bytes;
|
||||
::memcpy(dstZ, srcZ, slice.size[1] * slice.src.stride[1] * bytes);
|
||||
|
||||
for (int i = 0; i < mTempInputCopy.size(); i++) {
|
||||
auto& iter = mTempInputCopy[i];
|
||||
auto& slice = *(iter.second);
|
||||
auto srcPtr = (uint8_t*)iter.first + slice.src.offset * bytes;
|
||||
auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes;
|
||||
if (slice.src.stride[1] == slice.size[2] && slice.dst.stride[1] == slice.size[2] && slice.src.stride[2] == 1) {
|
||||
int subPatch = (slice.size[1] * slice.src.stride[1] * bytes) / threadNum;
|
||||
int extraPatch = slice.size[1] * slice.src.stride[1] * bytes - subPatch * threadNum;
|
||||
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
|
||||
for (int z = 0; z < slice.size[0]; ++z) {
|
||||
auto srcZ = srcPtr + subPatch * tId + z * slice.src.stride[0] * bytes;
|
||||
auto dstZ = dstPtr + subPatch * tId + z * slice.dst.stride[0] * bytes;
|
||||
::memcpy(dstZ, srcZ, subPatch);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (1 == slice.src.stride[2] && 1 == slice.dst.stride[2]) {
|
||||
for (int z=0; z<slice.size[0]; ++z) {
|
||||
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes;
|
||||
auto dstZ = dstPtr + z * slice.dst.stride[0] * bytes;
|
||||
for (int y=0; y<slice.size[1]; ++y) {
|
||||
MNN_CONCURRENCY_END();
|
||||
if (extraPatch > 0) {
|
||||
for (int z = 0; z < slice.size[0]; ++z) {
|
||||
auto srcZ = srcPtr + subPatch * threadNum + z * slice.src.stride[0] * bytes;
|
||||
auto dstZ = dstPtr + subPatch * threadNum + z * slice.dst.stride[0] * bytes;
|
||||
::memcpy(dstZ, srcZ, extraPatch);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (1 == slice.src.stride[2] && 1 == slice.dst.stride[2]) {
|
||||
for (int z = 0; z < slice.size[0]; ++z) {
|
||||
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes;
|
||||
auto dstZ = dstPtr + z * slice.dst.stride[0] * bytes;
|
||||
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
|
||||
for (int y = tId; y < slice.size[1]; y += threadNum) {
|
||||
auto srcY = srcZ + y * slice.src.stride[1] * bytes;
|
||||
auto dstY = dstZ + y * slice.dst.stride[1] * bytes;
|
||||
::memcpy(dstY, srcY, slice.size[2] * bytes);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
MNN_CONCURRENCY_END();
|
||||
}
|
||||
for (int z=0; z<slice.size[0]; ++z) {
|
||||
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes;
|
||||
auto dstZ = dstPtr + (z) * slice.dst.stride[0] * bytes;
|
||||
for (int y=0; y<slice.size[1]; ++y) {
|
||||
continue;
|
||||
}
|
||||
for (int z = 0; z < slice.size[0]; ++z) {
|
||||
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes;
|
||||
auto dstZ = dstPtr + (z) * slice.dst.stride[0] * bytes;
|
||||
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
|
||||
for (int y = tId; y < slice.size[1]; y += threadNum) {
|
||||
auto srcY = srcZ + y * slice.src.stride[1] * bytes;
|
||||
auto dstY = dstZ + y * slice.dst.stride[1] * bytes;
|
||||
proc(dstY, srcY, slice.size[2], slice.src.stride[2], slice.dst.stride[2]);
|
||||
}
|
||||
}
|
||||
MNN_CONCURRENCY_END();
|
||||
}
|
||||
}
|
||||
MNN_CONCURRENCY_END();
|
||||
|
||||
if (nullptr != mTempOutput) {
|
||||
backend()->onCopyBuffer(mTempOutput.get(), output);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue