improvement(arm82): optimize arm82 raster op use multi-threads

This commit is contained in:
gaoruidong 2020-12-23 17:47:34 +08:00
parent b0dbe49776
commit caeab08754
1 changed files with 85 additions and 50 deletions

View File

@ -183,7 +183,6 @@ ErrorCode Arm82Raster::onExecute(const std::vector<Tensor *> &inputs, const std:
bytes = 2; bytes = 2;
} }
auto threadNum = static_cast<Arm82Backend*>(backend())->numberThread(); auto threadNum = static_cast<Arm82Backend*>(backend())->numberThread();
if (mNeedZero) { if (mNeedZero) {
auto size = bytes; auto size = bytes;
const int dimensions = input->dimensions(); const int dimensions = input->dimensions();
@ -217,48 +216,67 @@ ErrorCode Arm82Raster::onExecute(const std::vector<Tensor *> &inputs, const std:
break; break;
} }
auto byteC4 = bytes * 8; auto byteC4 = bytes * 8;
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int u=tId; u<mFastBlit.size(); u+=threadNum) { for (int i = 0; i < mFastBlit.size(); i++) {
auto& iter = mFastBlit[u]; auto& iter = mFastBlit[i];
auto& slice = iter.second; auto& slice = iter.second;
//Offset use byte
auto srcPtr = (uint8_t*)iter.first + slice.src.offset * bytes; //Offset use byte
auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes; auto srcPtr = (uint8_t*)iter.first + slice.src.offset * bytes;
if (slice.src.stride[1] == slice.size[2] && slice.dst.stride[1] == slice.size[2] && slice.src.stride[2] == 1) { auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes;
for (int z=0; z<slice.size[0]; ++z) { if (slice.src.stride[1] == slice.size[2] && slice.dst.stride[1] == slice.size[2] && slice.src.stride[2] == 1) {
auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4; int subPatch = (slice.size[1] * slice.src.stride[1] * byteC4) / threadNum;
auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4; int extraPatch = slice.size[1] * slice.src.stride[1] * byteC4 - subPatch * threadNum;
::memcpy(dstZ, srcZ, slice.size[1] * slice.src.stride[1] * byteC4); MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int z = 0; z < slice.size[0]; ++z) {
auto srcZ = srcPtr + subPatch * tId + z * slice.src.stride[0] * byteC4;
auto dstZ = dstPtr + subPatch * tId + z * slice.dst.stride[0] * byteC4;
::memcpy(dstZ, srcZ, subPatch);
} }
continue;
} }
if (1 == slice.src.stride[2] && 1 == slice.dst.stride[2]) { MNN_CONCURRENCY_END();
for (int z=0; z<slice.size[0]; ++z) { if (extraPatch > 0) {
auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4; for (int z = 0; z < slice.size[0]; ++z) {
auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4; auto srcZ = srcPtr + subPatch * threadNum + z * slice.src.stride[0] * byteC4;
for (int y=0; y<slice.size[1]; ++y) { auto dstZ = dstPtr + subPatch * threadNum + z * slice.dst.stride[0] * byteC4;
auto srcY = srcZ + y * slice.src.stride[1] * byteC4; ::memcpy(dstZ, srcZ, extraPatch);
auto dstY = dstZ + y * slice.dst.stride[1] * byteC4;
::memcpy(dstY, srcY, slice.size[2] * byteC4);
}
} }
continue;
} }
continue;
}
if (1 == slice.src.stride[2] && 1 == slice.dst.stride[2]) {
for (int z=0; z<slice.size[0]; ++z) { for (int z=0; z<slice.size[0]; ++z) {
auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4; auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4;
auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4; auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4;
for (int y=0; y<slice.size[1]; ++y) { MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int y = tId; y < slice.size[1]; y += threadNum) {
auto srcY = srcZ + y * slice.src.stride[1] * byteC4;
auto dstY = dstZ + y * slice.dst.stride[1] * byteC4;
::memcpy(dstY, srcY, slice.size[2] * byteC4);
}
}
MNN_CONCURRENCY_END();
}
continue;
}
for (int z = 0; z < slice.size[0]; ++z) {
auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4;
auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4;
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int y = tId; y < slice.size[1]; y += threadNum) {
auto srcY = srcZ + y * slice.src.stride[1] * byteC4; auto srcY = srcZ + y * slice.src.stride[1] * byteC4;
auto dstY = dstZ + y * slice.dst.stride[1] * byteC4; auto dstY = dstZ + y * slice.dst.stride[1] * byteC4;
C4proc(dstY, srcY, slice.size[2], slice.src.stride[2], slice.dst.stride[2]); C4proc(dstY, srcY, slice.size[2], slice.src.stride[2], slice.dst.stride[2]);
} }
} }
MNN_CONCURRENCY_END();
} }
} }
MNN_CONCURRENCY_END();
return NO_ERROR; return NO_ERROR;
} }
for (auto& iter : mTempInput) { for (auto& iter : mTempInput) {
backend()->onCopyBuffer(iter.first, iter.second.get()); backend()->onCopyBuffer(iter.first, iter.second.get());
} }
@ -277,44 +295,61 @@ ErrorCode Arm82Raster::onExecute(const std::vector<Tensor *> &inputs, const std:
MNN_ASSERT(false); MNN_ASSERT(false);
break; break;
} }
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int u=(int)tId; u<mTempInputCopy.size(); u+=threadNum) { for (int i = 0; i < mTempInputCopy.size(); i++) {
auto& iter = mTempInputCopy[u]; auto& iter = mTempInputCopy[i];
auto& slice = *(iter.second); auto& slice = *(iter.second);
auto srcPtr = (uint8_t*)iter.first + slice.src.offset * bytes; auto srcPtr = (uint8_t*)iter.first + slice.src.offset * bytes;
auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes; auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes;
if (slice.src.stride[1] == slice.size[2] && slice.dst.stride[1] == slice.size[2] && slice.src.stride[2] == 1) { if (slice.src.stride[1] == slice.size[2] && slice.dst.stride[1] == slice.size[2] && slice.src.stride[2] == 1) {
for (int z=0; z<slice.size[0]; ++z) { int subPatch = (slice.size[1] * slice.src.stride[1] * bytes) / threadNum;
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes; int extraPatch = slice.size[1] * slice.src.stride[1] * bytes - subPatch * threadNum;
auto dstZ = dstPtr + z * slice.dst.stride[0] * bytes; MNN_CONCURRENCY_BEGIN(tId, threadNum) {
::memcpy(dstZ, srcZ, slice.size[1] * slice.src.stride[1] * bytes); for (int z = 0; z < slice.size[0]; ++z) {
auto srcZ = srcPtr + subPatch * tId + z * slice.src.stride[0] * bytes;
auto dstZ = dstPtr + subPatch * tId + z * slice.dst.stride[0] * bytes;
::memcpy(dstZ, srcZ, subPatch);
} }
continue;
} }
if (1 == slice.src.stride[2] && 1 == slice.dst.stride[2]) { MNN_CONCURRENCY_END();
for (int z=0; z<slice.size[0]; ++z) { if (extraPatch > 0) {
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes; for (int z = 0; z < slice.size[0]; ++z) {
auto dstZ = dstPtr + z * slice.dst.stride[0] * bytes; auto srcZ = srcPtr + subPatch * threadNum + z * slice.src.stride[0] * bytes;
for (int y=0; y<slice.size[1]; ++y) { auto dstZ = dstPtr + subPatch * threadNum + z * slice.dst.stride[0] * bytes;
::memcpy(dstZ, srcZ, extraPatch);
}
}
continue;
}
if (1 == slice.src.stride[2] && 1 == slice.dst.stride[2]) {
for (int z = 0; z < slice.size[0]; ++z) {
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes;
auto dstZ = dstPtr + z * slice.dst.stride[0] * bytes;
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int y = tId; y < slice.size[1]; y += threadNum) {
auto srcY = srcZ + y * slice.src.stride[1] * bytes; auto srcY = srcZ + y * slice.src.stride[1] * bytes;
auto dstY = dstZ + y * slice.dst.stride[1] * bytes; auto dstY = dstZ + y * slice.dst.stride[1] * bytes;
::memcpy(dstY, srcY, slice.size[2] * bytes); ::memcpy(dstY, srcY, slice.size[2] * bytes);
} }
} }
continue; MNN_CONCURRENCY_END();
} }
for (int z=0; z<slice.size[0]; ++z) { continue;
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes; }
auto dstZ = dstPtr + (z) * slice.dst.stride[0] * bytes; for (int z = 0; z < slice.size[0]; ++z) {
for (int y=0; y<slice.size[1]; ++y) { auto srcZ = srcPtr + z * slice.src.stride[0] * bytes;
auto dstZ = dstPtr + (z) * slice.dst.stride[0] * bytes;
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
for (int y = tId; y < slice.size[1]; y += threadNum) {
auto srcY = srcZ + y * slice.src.stride[1] * bytes; auto srcY = srcZ + y * slice.src.stride[1] * bytes;
auto dstY = dstZ + y * slice.dst.stride[1] * bytes; auto dstY = dstZ + y * slice.dst.stride[1] * bytes;
proc(dstY, srcY, slice.size[2], slice.src.stride[2], slice.dst.stride[2]); proc(dstY, srcY, slice.size[2], slice.src.stride[2], slice.dst.stride[2]);
} }
} }
MNN_CONCURRENCY_END();
} }
} }
MNN_CONCURRENCY_END();
if (nullptr != mTempOutput) { if (nullptr != mTempOutput) {
backend()->onCopyBuffer(mTempOutput.get(), output); backend()->onCopyBuffer(mTempOutput.get(), output);
} }