mirror of https://github.com/alibaba/MNN.git
				
				
				
			improvement(arm82): optimize the arm82 raster op using multiple threads
This commit is contained in:
		
							parent
							
								
									b0dbe49776
								
							
						
					
					
						commit
						caeab08754
					
				|  | @ -183,7 +183,6 @@ ErrorCode Arm82Raster::onExecute(const std::vector<Tensor *> &inputs, const std: | |||
|         bytes = 2; | ||||
|     } | ||||
|     auto threadNum = static_cast<Arm82Backend*>(backend())->numberThread(); | ||||
| 
 | ||||
|     if (mNeedZero) { | ||||
|         auto size          = bytes; | ||||
|         const int dimensions = input->dimensions(); | ||||
|  | @ -217,48 +216,67 @@ ErrorCode Arm82Raster::onExecute(const std::vector<Tensor *> &inputs, const std: | |||
|                 break; | ||||
|         } | ||||
|         auto byteC4 = bytes * 8; | ||||
|         MNN_CONCURRENCY_BEGIN(tId, threadNum) { | ||||
|             for (int u=tId; u<mFastBlit.size(); u+=threadNum) { | ||||
|                 auto& iter = mFastBlit[u]; | ||||
| 
 | ||||
|         for (int i = 0; i < mFastBlit.size(); i++) { | ||||
|             auto& iter = mFastBlit[i]; | ||||
|             auto& slice = iter.second; | ||||
| 
 | ||||
|             //Offset use byte
 | ||||
|             auto srcPtr = (uint8_t*)iter.first + slice.src.offset * bytes; | ||||
|             auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes; | ||||
|             if (slice.src.stride[1] == slice.size[2] && slice.dst.stride[1] == slice.size[2] && slice.src.stride[2] == 1) { | ||||
|                 int subPatch = (slice.size[1] * slice.src.stride[1] * byteC4) / threadNum;  | ||||
|                 int extraPatch = slice.size[1] * slice.src.stride[1] * byteC4 - subPatch * threadNum; | ||||
|                 MNN_CONCURRENCY_BEGIN(tId, threadNum) { | ||||
|                     for (int z = 0; z < slice.size[0]; ++z) { | ||||
|                         auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4; | ||||
|                         auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4; | ||||
|                         ::memcpy(dstZ, srcZ, slice.size[1] * slice.src.stride[1] * byteC4); | ||||
|                         auto srcZ = srcPtr + subPatch * tId + z * slice.src.stride[0] * byteC4; | ||||
|                         auto dstZ = dstPtr + subPatch * tId + z * slice.dst.stride[0] * byteC4; | ||||
|                         ::memcpy(dstZ, srcZ, subPatch); | ||||
|                     } | ||||
|                 } | ||||
|                 MNN_CONCURRENCY_END(); | ||||
|                 if (extraPatch > 0) { | ||||
|                     for (int z = 0; z < slice.size[0]; ++z) { | ||||
|                         auto srcZ = srcPtr + subPatch * threadNum + z * slice.src.stride[0] * byteC4; | ||||
|                         auto dstZ = dstPtr + subPatch * threadNum + z * slice.dst.stride[0] * byteC4; | ||||
|                         ::memcpy(dstZ, srcZ, extraPatch); | ||||
|                     } | ||||
|                 } | ||||
|                 continue; | ||||
|             } | ||||
| 
 | ||||
|             if (1 == slice.src.stride[2] && 1 == slice.dst.stride[2]) { | ||||
|                 for (int z=0; z<slice.size[0]; ++z) { | ||||
|                     auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4; | ||||
|                     auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4; | ||||
|                         for (int y=0; y<slice.size[1]; ++y) { | ||||
|                     MNN_CONCURRENCY_BEGIN(tId, threadNum) { | ||||
|                         for (int y = tId; y < slice.size[1]; y += threadNum) { | ||||
|                             auto srcY = srcZ + y * slice.src.stride[1] * byteC4; | ||||
|                             auto dstY = dstZ + y * slice.dst.stride[1] * byteC4; | ||||
|                             ::memcpy(dstY, srcY, slice.size[2] * byteC4); | ||||
|                         }  | ||||
|                     } | ||||
|                     MNN_CONCURRENCY_END();                        | ||||
|                 } | ||||
|                 continue; | ||||
|             } | ||||
| 
 | ||||
|             for (int z = 0; z < slice.size[0]; ++z) { | ||||
|                 auto srcZ = srcPtr + z * slice.src.stride[0] * byteC4; | ||||
|                 auto dstZ = dstPtr + z * slice.dst.stride[0] * byteC4; | ||||
|                     for (int y=0; y<slice.size[1]; ++y) { | ||||
|                 MNN_CONCURRENCY_BEGIN(tId, threadNum) { | ||||
|                     for (int y = tId; y < slice.size[1]; y += threadNum) { | ||||
|                         auto srcY = srcZ + y * slice.src.stride[1] * byteC4; | ||||
|                         auto dstY = dstZ + y * slice.dst.stride[1] * byteC4; | ||||
|                         C4proc(dstY, srcY, slice.size[2], slice.src.stride[2], slice.dst.stride[2]); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|                 MNN_CONCURRENCY_END(); | ||||
| 
 | ||||
|             } | ||||
|         } | ||||
|         return NO_ERROR; | ||||
|     } | ||||
| 
 | ||||
|     for (auto& iter : mTempInput) { | ||||
|         backend()->onCopyBuffer(iter.first, iter.second.get()); | ||||
|     } | ||||
|  | @ -277,17 +295,29 @@ ErrorCode Arm82Raster::onExecute(const std::vector<Tensor *> &inputs, const std: | |||
|             MNN_ASSERT(false); | ||||
|             break; | ||||
|     } | ||||
|     MNN_CONCURRENCY_BEGIN(tId, threadNum) { | ||||
|         for (int u=(int)tId; u<mTempInputCopy.size(); u+=threadNum) { | ||||
|             auto& iter = mTempInputCopy[u]; | ||||
| 
 | ||||
|     for (int i = 0; i < mTempInputCopy.size(); i++) { | ||||
|         auto& iter = mTempInputCopy[i]; | ||||
|         auto& slice = *(iter.second); | ||||
|         auto srcPtr = (uint8_t*)iter.first + slice.src.offset * bytes; | ||||
|         auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes; | ||||
|         if (slice.src.stride[1] == slice.size[2] && slice.dst.stride[1] == slice.size[2] && slice.src.stride[2] == 1) { | ||||
|             int subPatch = (slice.size[1] * slice.src.stride[1] * bytes) / threadNum; | ||||
|             int extraPatch = slice.size[1] * slice.src.stride[1] * bytes - subPatch * threadNum; | ||||
|             MNN_CONCURRENCY_BEGIN(tId, threadNum) { | ||||
|                 for (int z = 0; z < slice.size[0]; ++z) { | ||||
|                     auto srcZ = srcPtr + z * slice.src.stride[0] * bytes; | ||||
|                     auto dstZ = dstPtr + z * slice.dst.stride[0] * bytes; | ||||
|                     ::memcpy(dstZ, srcZ, slice.size[1] * slice.src.stride[1] * bytes); | ||||
|                     auto srcZ = srcPtr + subPatch * tId + z * slice.src.stride[0] * bytes; | ||||
|                     auto dstZ = dstPtr + subPatch * tId + z * slice.dst.stride[0] * bytes; | ||||
|                     ::memcpy(dstZ, srcZ, subPatch); | ||||
|                 } | ||||
|             } | ||||
|             MNN_CONCURRENCY_END(); | ||||
|             if (extraPatch > 0) { | ||||
|                 for (int z = 0; z < slice.size[0]; ++z) { | ||||
|                     auto srcZ = srcPtr + subPatch * threadNum + z * slice.src.stride[0] * bytes; | ||||
|                     auto dstZ = dstPtr + subPatch * threadNum + z * slice.dst.stride[0] * bytes; | ||||
|                     ::memcpy(dstZ, srcZ, extraPatch); | ||||
|                 } | ||||
|             } | ||||
|             continue; | ||||
|         } | ||||
|  | @ -295,26 +325,31 @@ ErrorCode Arm82Raster::onExecute(const std::vector<Tensor *> &inputs, const std: | |||
|             for (int z = 0; z < slice.size[0]; ++z) { | ||||
|                 auto srcZ = srcPtr + z * slice.src.stride[0] * bytes; | ||||
|                 auto dstZ = dstPtr + z * slice.dst.stride[0] * bytes; | ||||
|                     for (int y=0; y<slice.size[1]; ++y) { | ||||
|                 MNN_CONCURRENCY_BEGIN(tId, threadNum) { | ||||
|                     for (int y = tId; y < slice.size[1]; y += threadNum) { | ||||
|                         auto srcY = srcZ + y * slice.src.stride[1] * bytes; | ||||
|                         auto dstY = dstZ + y * slice.dst.stride[1] * bytes; | ||||
|                         ::memcpy(dstY, srcY, slice.size[2] * bytes); | ||||
|                     } | ||||
|                 } | ||||
|                 MNN_CONCURRENCY_END(); | ||||
|             } | ||||
|             continue; | ||||
|         } | ||||
|         for (int z = 0; z < slice.size[0]; ++z) { | ||||
|             auto srcZ = srcPtr + z * slice.src.stride[0] * bytes; | ||||
|             auto dstZ = dstPtr + (z) * slice.dst.stride[0] * bytes; | ||||
|                 for (int y=0; y<slice.size[1]; ++y) { | ||||
|             MNN_CONCURRENCY_BEGIN(tId, threadNum) { | ||||
|                 for (int y = tId; y < slice.size[1]; y += threadNum) { | ||||
|                     auto srcY = srcZ + y * slice.src.stride[1] * bytes; | ||||
|                     auto dstY = dstZ + y * slice.dst.stride[1] * bytes; | ||||
|                     proc(dstY, srcY, slice.size[2], slice.src.stride[2], slice.dst.stride[2]); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|             MNN_CONCURRENCY_END(); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     if (nullptr != mTempOutput) { | ||||
|         backend()->onCopyBuffer(mTempOutput.get(), output); | ||||
|     } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue