improvement(arm82): optimize arm82 raster op using multiple threads

This commit is contained in:
gaoruidong 2020-12-23 17:47:34 +08:00
parent b0dbe49776
commit caeab08754
1 changed files with 85 additions and 50 deletions

View File

@ -183,7 +183,6 @@ ErrorCode Arm82Raster::onExecute(const std::vector<Tensor *> &inputs, const std:
bytes = 2;
}
auto threadNum = static_cast<Arm82Backend*>(backend())->numberThread();
if (mNeedZero) {
auto size = bytes;
const int dimensions = input->dimensions();
@ -217,48 +216,67 @@ ErrorCode Arm82Raster::onExecute(const std::vector<Tensor *> &inputs, const std:
break;
}
auto byteC4 = bytes * 8;
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int u=tId; u<mFastBlit.size(); u+=threadNum) {
auto& iter = mFastBlit[u];
auto& slice = iter.second;
//Offset use byte
auto srcPtr = (uint8_t*)iter.first + slice.src.offset * bytes;
auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes;
if (slice.src.stride[1] == slice.size[2] && slice.dst.stride[1] == slice.size[2] && slice.src.stride[2] == 1) {
for (int z=0; z<slice.size[0]; ++z) {
auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4;
auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4;
::memcpy(dstZ, srcZ, slice.size[1] * slice.src.stride[1] * byteC4);
for (int i = 0; i < mFastBlit.size(); i++) {
auto& iter = mFastBlit[i];
auto& slice = iter.second;
//Offset use byte
auto srcPtr = (uint8_t*)iter.first + slice.src.offset * bytes;
auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes;
if (slice.src.stride[1] == slice.size[2] && slice.dst.stride[1] == slice.size[2] && slice.src.stride[2] == 1) {
int subPatch = (slice.size[1] * slice.src.stride[1] * byteC4) / threadNum;
int extraPatch = slice.size[1] * slice.src.stride[1] * byteC4 - subPatch * threadNum;
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int z = 0; z < slice.size[0]; ++z) {
auto srcZ = srcPtr + subPatch * tId + z * slice.src.stride[0] * byteC4;
auto dstZ = dstPtr + subPatch * tId + z * slice.dst.stride[0] * byteC4;
::memcpy(dstZ, srcZ, subPatch);
}
continue;
}
if (1 == slice.src.stride[2] && 1 == slice.dst.stride[2]) {
for (int z=0; z<slice.size[0]; ++z) {
auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4;
auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4;
for (int y=0; y<slice.size[1]; ++y) {
auto srcY = srcZ + y * slice.src.stride[1] * byteC4;
auto dstY = dstZ + y * slice.dst.stride[1] * byteC4;
::memcpy(dstY, srcY, slice.size[2] * byteC4);
}
MNN_CONCURRENCY_END();
if (extraPatch > 0) {
for (int z = 0; z < slice.size[0]; ++z) {
auto srcZ = srcPtr + subPatch * threadNum + z * slice.src.stride[0] * byteC4;
auto dstZ = dstPtr + subPatch * threadNum + z * slice.dst.stride[0] * byteC4;
::memcpy(dstZ, srcZ, extraPatch);
}
continue;
}
continue;
}
if (1 == slice.src.stride[2] && 1 == slice.dst.stride[2]) {
for (int z=0; z<slice.size[0]; ++z) {
auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4;
auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4;
for (int y=0; y<slice.size[1]; ++y) {
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int y = tId; y < slice.size[1]; y += threadNum) {
auto srcY = srcZ + y * slice.src.stride[1] * byteC4;
auto dstY = dstZ + y * slice.dst.stride[1] * byteC4;
::memcpy(dstY, srcY, slice.size[2] * byteC4);
}
}
MNN_CONCURRENCY_END();
}
continue;
}
for (int z = 0; z < slice.size[0]; ++z) {
auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4;
auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4;
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int y = tId; y < slice.size[1]; y += threadNum) {
auto srcY = srcZ + y * slice.src.stride[1] * byteC4;
auto dstY = dstZ + y * slice.dst.stride[1] * byteC4;
C4proc(dstY, srcY, slice.size[2], slice.src.stride[2], slice.dst.stride[2]);
}
}
MNN_CONCURRENCY_END();
}
}
MNN_CONCURRENCY_END();
return NO_ERROR;
}
for (auto& iter : mTempInput) {
backend()->onCopyBuffer(iter.first, iter.second.get());
}
@ -277,44 +295,61 @@ ErrorCode Arm82Raster::onExecute(const std::vector<Tensor *> &inputs, const std:
MNN_ASSERT(false);
break;
}
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int u=(int)tId; u<mTempInputCopy.size(); u+=threadNum) {
auto& iter = mTempInputCopy[u];
auto& slice = *(iter.second);
auto srcPtr = (uint8_t*)iter.first + slice.src.offset * bytes;
auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes;
if (slice.src.stride[1] == slice.size[2] && slice.dst.stride[1] == slice.size[2] && slice.src.stride[2] == 1) {
for (int z=0; z<slice.size[0]; ++z) {
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes;
auto dstZ = dstPtr + z * slice.dst.stride[0] * bytes;
::memcpy(dstZ, srcZ, slice.size[1] * slice.src.stride[1] * bytes);
for (int i = 0; i < mTempInputCopy.size(); i++) {
auto& iter = mTempInputCopy[i];
auto& slice = *(iter.second);
auto srcPtr = (uint8_t*)iter.first + slice.src.offset * bytes;
auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes;
if (slice.src.stride[1] == slice.size[2] && slice.dst.stride[1] == slice.size[2] && slice.src.stride[2] == 1) {
int subPatch = (slice.size[1] * slice.src.stride[1] * bytes) / threadNum;
int extraPatch = slice.size[1] * slice.src.stride[1] * bytes - subPatch * threadNum;
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int z = 0; z < slice.size[0]; ++z) {
auto srcZ = srcPtr + subPatch * tId + z * slice.src.stride[0] * bytes;
auto dstZ = dstPtr + subPatch * tId + z * slice.dst.stride[0] * bytes;
::memcpy(dstZ, srcZ, subPatch);
}
continue;
}
if (1 == slice.src.stride[2] && 1 == slice.dst.stride[2]) {
for (int z=0; z<slice.size[0]; ++z) {
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes;
auto dstZ = dstPtr + z * slice.dst.stride[0] * bytes;
for (int y=0; y<slice.size[1]; ++y) {
MNN_CONCURRENCY_END();
if (extraPatch > 0) {
for (int z = 0; z < slice.size[0]; ++z) {
auto srcZ = srcPtr + subPatch * threadNum + z * slice.src.stride[0] * bytes;
auto dstZ = dstPtr + subPatch * threadNum + z * slice.dst.stride[0] * bytes;
::memcpy(dstZ, srcZ, extraPatch);
}
}
continue;
}
if (1 == slice.src.stride[2] && 1 == slice.dst.stride[2]) {
for (int z = 0; z < slice.size[0]; ++z) {
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes;
auto dstZ = dstPtr + z * slice.dst.stride[0] * bytes;
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int y = tId; y < slice.size[1]; y += threadNum) {
auto srcY = srcZ + y * slice.src.stride[1] * bytes;
auto dstY = dstZ + y * slice.dst.stride[1] * bytes;
::memcpy(dstY, srcY, slice.size[2] * bytes);
}
}
continue;
MNN_CONCURRENCY_END();
}
for (int z=0; z<slice.size[0]; ++z) {
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes;
auto dstZ = dstPtr + (z) * slice.dst.stride[0] * bytes;
for (int y=0; y<slice.size[1]; ++y) {
continue;
}
for (int z = 0; z < slice.size[0]; ++z) {
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes;
auto dstZ = dstPtr + (z) * slice.dst.stride[0] * bytes;
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int y = tId; y < slice.size[1]; y += threadNum) {
auto srcY = srcZ + y * slice.src.stride[1] * bytes;
auto dstY = dstZ + y * slice.dst.stride[1] * bytes;
proc(dstY, srcY, slice.size[2], slice.src.stride[2], slice.dst.stride[2]);
}
}
MNN_CONCURRENCY_END();
}
}
MNN_CONCURRENCY_END();
if (nullptr != mTempOutput) {
backend()->onCopyBuffer(mTempOutput.get(), output);
}