mirror of https://github.com/alibaba/MNN.git
| #include "TopKV2Execution.hpp"
 | |
| #include <memory>
 | |
| 
 | |
| namespace MNN {
 | |
| namespace CUDA {
 | |
| 
 | |
| 
 | |
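// Overview: TopKV2 runs in two passes. TopKAllRows lets every block compute a partial
// top-K over its slice of a row and stores the K candidates per block in a temporary
// global buffer; GetResultAllRows then uses one block per row to merge those candidates
// into the final top-K values and indices.
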
// rank TopK in the corresponding thread
template<typename indexT, typename valueT>
__device__ void TopKInThread(const valueT * inputDevice, indexT * indicesThread, valueT * valuesThread, const int K, const int numElePerRow, const valueT minValue, const int descendFlag) {
    // initialize the thread-local candidate list with sentinel values
    for (int i = 0; i < K; i++) {
        indicesThread[i] = -1;
        valuesThread[i] = (valueT)(descendFlag) * minValue;
    }

    int idxFirstEleInRow = threadIdx.x + blockIdx.x * blockDim.x;

    // grid-stride loop over the row, keeping the candidate list sorted by descendFlag * value
    for (indexT i = idxFirstEleInRow; i < numElePerRow; i += gridDim.x * blockDim.x) {
        valueT data = inputDevice[i];
        if ((valueT)(descendFlag) * data <= (valueT)(descendFlag) * valuesThread[K - 1]) {
            continue;
        } else {
            // the new candidate beats the current K-th entry: place it in the last slot
            // first so it is kept even if it moves no further, then bubble it upwards
            valuesThread[K - 1] = data;
            indicesThread[K - 1] = i;
            for (int j = K - 2; j >= 0; j--) {
                if ((valueT)(descendFlag) * data > (valueT)(descendFlag) * valuesThread[j]) {
                    valuesThread[j + 1] = valuesThread[j];
                    indicesThread[j + 1] = indicesThread[j];
                    valuesThread[j] = data;
                    indicesThread[j] = i;
                } else {
                    break;
                }
            }
        }
    }

    return;
}


// reduce the TopK results of two offsets: merge the two sorted K-candidate lists stored
// at offset1 and offset2, keeping the best K in the slots starting at offset1
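// e.g. with K = 2 and descendFlag = 1, merging {9, 4} (offset1) with {7, 5} (offset2)
// leaves {9, 7} in the two slots at offset1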
template<typename indexT, typename valueT>
__device__ void ReduceTopK(indexT * indicesArray, valueT * valuesArray, const int offset1, const int offset2, const int K, const int descendFlag) {
    indexT idx1 = offset1 + K - 1;
    indexT idx2 = offset2 + K - 1;
    indexT idxVirtual = offset1 + 2 * K - 1;

    while (idx2 >= offset2) {
        // list1 is exhausted: copy the remaining candidates of list2 into place
        if (idx1 < offset1) {
            while (idxVirtual >= offset1) {
                indicesArray[idxVirtual] = indicesArray[offset2 + (idxVirtual - offset1)];
                valuesArray[idxVirtual] = valuesArray[offset2 + (idxVirtual - offset1)];
                idxVirtual--;
            }
            break;
        }

        // merge from the tails; only the first K merged positions are actually written
        if ((valueT)(descendFlag) * valuesArray[idx1] <= (valueT)(descendFlag) * valuesArray[idx2]) {
            if (idxVirtual <= offset1 + K - 1) {
                indicesArray[idxVirtual] = indicesArray[idx1];
                valuesArray[idxVirtual] = valuesArray[idx1];
            }
            idx1--;
        } else {
            if (idxVirtual <= offset1 + K - 1) {
                indicesArray[idxVirtual] = indicesArray[idx2];
                valuesArray[idxVirtual] = valuesArray[idx2];
            }
            idx2--;
        }
        idxVirtual--;
    }

    return;
}



// get results of all blocks' TopK in one row
template<typename indexT, typename valueT>
__device__ void TopKOneRow(const valueT * inputDevice, indexT * indicesBlock, valueT * valuesBlock, indexT * tempIndicesDevice, valueT * tempValuesDevice, const int K, const int lengthRow, valueT minValue, const int descendFlag) {
    indexT * indicesThread = indicesBlock + threadIdx.x * K;
    valueT * valuesThread = valuesBlock + threadIdx.x * K;

    // rank TopK
    TopKInThread<indexT, valueT>(inputDevice, indicesThread, valuesThread, K, lengthRow, minValue, descendFlag);

    __syncthreads();

    // reduce
    for (int stride = (blockDim.x >> 1); stride > 0; stride >>= 1) {
        if (threadIdx.x < stride) {
            ReduceTopK<indexT, valueT>(indicesBlock, valuesBlock, threadIdx.x * K, (threadIdx.x + stride) * K, K, descendFlag);
        }
        __syncthreads();
    }

    // move data from the block's smem to global memory (prepare for the next kernel function)
    if (threadIdx.x == 0) {
        for (int i = 0; i < K; i++) {
            tempIndicesDevice[K * blockIdx.x + i] = indicesBlock[i];
            tempValuesDevice[K * blockIdx.x + i] = valuesBlock[i];
        }
    }

    return;
}



// get results of the final TopK from all blocks' TopK in a row
template<typename indexT, typename valueT>
__device__ void GetResultOneRow(indexT * outputIndicesDevice, valueT * outputValuesDevice, indexT * tempIndicesDevice, valueT * tempValuesDevice, indexT * finalIndices, valueT * finalValues, const int K, const int reduceLength, const int descendFlag) {
    // move data from global memory to the block's smem
    if (threadIdx.x < reduceLength) {
        for (int i = 0; i < K; i++) {
            finalIndices[threadIdx.x * K + i] = tempIndicesDevice[threadIdx.x * K + i];
            finalValues[threadIdx.x * K + i] = tempValuesDevice[threadIdx.x * K + i];
        }
    }
    __syncthreads();

    // the first round of reducing needs special handling: blockDim.x is rounded up to a
    // power of two, so a thread's partner list may not exist
    int stride = blockDim.x >> 1;
    if ((threadIdx.x < stride) && (threadIdx.x + stride < reduceLength)) {
        ReduceTopK<indexT, valueT>(finalIndices, finalValues, threadIdx.x * K, (threadIdx.x + stride) * K, K, descendFlag);
    }
    __syncthreads();
    stride >>= 1;

    // the remaining rounds of reducing
    for (; stride > 0; stride >>= 1) {
        if (threadIdx.x < stride) {
            ReduceTopK<indexT, valueT>(finalIndices, finalValues, threadIdx.x * K, (threadIdx.x + stride) * K, K, descendFlag);
        }
        __syncthreads();
    }

    // move data from the block's smem to global memory
    if (threadIdx.x == 0) {
        for (int i = 0; i < K; i++) {
            outputIndicesDevice[i] = finalIndices[i];
            outputValuesDevice[i] = finalValues[i];
        }
    }

    return;
}



// compute per-row addresses and call TopKOneRow (grid.y indexes rows, grid.x splits each row across blocks)
template<typename indexT, typename valueT>
__global__ void TopKAllRows(const valueT * inputDevice, indexT * tempIndicesDevice, valueT * tempValuesDevice, const int K, const int lengthRow, valueT minValue, const int descendFlag) {
    // dynamic shared memory: blockDim.x * K index slots followed by blockDim.x * K value slots
    extern __shared__ char smem[];
    indexT * indicesBlock = reinterpret_cast<indexT *>(smem);
    valueT * valuesBlock = reinterpret_cast<valueT *>(&smem[blockDim.x * K * sizeof(indexT)]);

    int idxRow = blockIdx.y;

    const valueT * inputDeviceThisRow = inputDevice + idxRow * lengthRow;
    indexT * tempIndicesDeviceThisRow = tempIndicesDevice + idxRow * gridDim.x * K;
    valueT * tempValuesDeviceThisRow = tempValuesDevice + idxRow * gridDim.x * K;

    TopKOneRow<indexT, valueT>(inputDeviceThisRow, indicesBlock, valuesBlock, tempIndicesDeviceThisRow, tempValuesDeviceThisRow, K, lengthRow, minValue, descendFlag);
    __syncthreads();

    return;
}



// compute per-row addresses and call GetResultOneRow
// This kernel assumes that each row of data corresponds to one block.
template<typename indexT, typename valueT>
__global__ void GetResultAllRows(indexT * outputIndicesDevice, valueT * outputValuesDevice, indexT * tempIndicesDevice, valueT * tempValuesDevice, const int K, const int numBlockPerRow, const int descendFlag) {
    // dynamic shared memory: numBlockPerRow * K index slots followed by numBlockPerRow * K value slots
    extern __shared__ char smem[];
    indexT * finalIndices = reinterpret_cast<indexT *>(smem);
    valueT * finalValues = reinterpret_cast<valueT *>(&smem[numBlockPerRow * K * sizeof(indexT)]);

    int idxRow = blockIdx.x; // each block corresponds to a row

    indexT * outputIndicesDeviceThisRow = outputIndicesDevice + idxRow * K;
    valueT * outputValuesDeviceThisRow = outputValuesDevice + idxRow * K;
    indexT * tempIndicesDeviceThisRow = tempIndicesDevice + idxRow * numBlockPerRow * K;
    valueT * tempValuesDeviceThisRow = tempValuesDevice + idxRow * numBlockPerRow * K;

    GetResultOneRow<indexT, valueT>(outputIndicesDeviceThisRow, outputValuesDeviceThisRow, tempIndicesDeviceThisRow, tempValuesDeviceThisRow, finalIndices, finalValues, K, numBlockPerRow, descendFlag);

    return;
}



// The inequality "numThreadPerBlock * K * (sizeof(indexT) + sizeof(valueT)) <= smemPerBlock" must be guaranteed, which means numThreadPerBlock depends on K.
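// e.g. with 48 KiB of shared memory per block, K = 32 and <int, float> candidates
// (8 bytes each): temp = 49152 / 256 = 192, rounded down to the power of two 128,
// so 128 threads are used per block (the 1024-thread cap does not bind here)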
template<typename indexT, typename valueT>
int CalculateNumThreadPerBlock(const int K, const int smemPerBlock) {
    int temp = smemPerBlock / (K * (sizeof(indexT) + sizeof(valueT)));
    int numCalculate = std::pow(2, (std::floor(std::log2(temp))));
    int numLimit = 1024;
    return ALIMIN(numLimit, numCalculate);
}



// The inequality "numBlockPerRow * K * (sizeof(indexT) + sizeof(valueT)) <= smemPerBlock" must be guaranteed by restricting numElePerThread.
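// e.g. continuing the numbers above (48 KiB smem, K = 32, <int, float>): a row of 100000
// elements gives numCalculate = UP_DIV(100000, 192 - 1) = 524 elements per thread, which
// later yields mNumElePerBlock = 524 * 128 = 67072 and mNumBlockPerRow = 2, so the final
// reduction needs only 2 * 32 * 8 = 512 bytes of shared memory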
template<typename indexT, typename valueT>
int CalcualteNumElePerThread(const int K, const int numElePerRow, const int numThreadPerBlock, const int smemPerBlock) {
    int numLimit = K;
    int numCalculate = UP_DIV(numElePerRow, (smemPerBlock / (K * (sizeof(indexT) + sizeof(valueT)))) - 1);
    return ALIMAX(numLimit, numCalculate);
}



TopKV2Execution::TopKV2Execution(const Op* op, Backend* backend) : Execution(backend) {
    mOp = op;
}



ErrorCode TopKV2Execution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    // prepare some params for the kernel function
    Tensor * inputTensor = inputs[0];

    int lengthRow = inputTensor->buffer().dim[inputTensor->buffer().dimensions - 1].extent;
    int numRow = inputTensor->elementSize() / lengthRow;

    mParams.mLengthRow = lengthRow;
    mParams.mNumRow = numRow;

    // rank in descending order (take the largest K) when TopKV2 parameters are present
    auto boolDescendFlag = mOp->main_as_TopKV2();
    if (boolDescendFlag != nullptr) {
        mParams.mDescendFlag = boolDescendFlag ? 1 : -1;
    }

    mParams.mNumElePerRow = mParams.mLengthRow;
    mParams.mNumK = outputs[0]->buffer().dim[outputs[0]->buffer().dimensions - 1].extent;
    auto smemLimit = static_cast<CUDABackend*>(backend())->getCUDARuntime()->smemPerBlock();
    if (inputTensor->getType().code == halide_type_int && inputTensor->getType().bits == 32) {
        mParams.mNumThreadPerBlock = CalculateNumThreadPerBlock<int, int>(mParams.mNumK, smemLimit);
        mParams.mNumElePerThread = CalcualteNumElePerThread<int, int>(mParams.mNumK, mParams.mNumElePerRow, mParams.mNumThreadPerBlock, smemLimit);
    } else if (static_cast<CUDABackend*>(backend())->useFp16()) {
        mParams.mNumThreadPerBlock = CalculateNumThreadPerBlock<int, half>(mParams.mNumK, smemLimit);
        mParams.mNumElePerThread = CalcualteNumElePerThread<int, half>(mParams.mNumK, mParams.mNumElePerRow, mParams.mNumThreadPerBlock, smemLimit);
    } else {
        mParams.mNumThreadPerBlock = CalculateNumThreadPerBlock<int, float>(mParams.mNumK, smemLimit);
        mParams.mNumElePerThread = CalcualteNumElePerThread<int, float>(mParams.mNumK, mParams.mNumElePerRow, mParams.mNumThreadPerBlock, smemLimit);
    }
    mParams.mNumElePerBlock = mParams.mNumElePerThread * mParams.mNumThreadPerBlock;
    mParams.mNumBlockPerRow = (mParams.mNumElePerRow - 1 + mParams.mNumElePerBlock) / mParams.mNumElePerBlock; // ceiling division
    mParams.mNumBlockFinal = mParams.mNumRow;
    mParams.mNumThreadFinal = std::pow(2, (std::ceil(std::log2(mParams.mNumBlockPerRow)))); // smallest power of two >= mNumBlockPerRow
    mParams.mNumBlockTotal = mParams.mNumBlockPerRow * mParams.mNumRow;

    // prepare temp buffer
    auto pool = static_cast<CUDABackend*>(backend())->getBufferPool();

    if (inputTensor->getType().code == halide_type_int && inputTensor->getType().bits == 32) {
        auto bufferIndices = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(int));
        mParams.mBufferIndices = (void*)((uint8_t*)bufferIndices.first + bufferIndices.second);
        auto bufferValues = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(int));
        mParams.mBufferValues = (void*)((uint8_t*)bufferValues.first + bufferValues.second);
        // record the addresses, then return the allocations to the buffer pool,
        // following the backend's resize-time allocation pattern
        pool->free(bufferIndices);
        pool->free(bufferValues);
    } else if (static_cast<CUDABackend*>(backend())->useFp16()) {
        auto bufferIndices = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(int));
        mParams.mBufferIndices = (void*)((uint8_t*)bufferIndices.first + bufferIndices.second);
        auto bufferValues = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(half));
        mParams.mBufferValues = (void*)((uint8_t*)bufferValues.first + bufferValues.second);
        pool->free(bufferIndices);
        pool->free(bufferValues);
    } else {
        auto bufferIndices = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(int));
        mParams.mBufferIndices = (void*)((uint8_t*)bufferIndices.first + bufferIndices.second);
        auto bufferValues = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(float));
        mParams.mBufferValues = (void*)((uint8_t*)bufferValues.first + bufferValues.second);
        pool->free(bufferIndices);
        pool->free(bufferValues);
    }

    return NO_ERROR;
}



ErrorCode TopKV2Execution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    // get input and output pointers
    void * inputDeviceAddr = reinterpret_cast<void *>(inputs[0]->deviceId());
    void * outputIndicesDeviceAddr = reinterpret_cast<void *>(outputs[1]->deviceId());
    void * outputValuesDeviceAddr = reinterpret_cast<void *>(outputs[0]->deviceId());

    // configure threads
    dim3 grid1 = {mParams.mNumBlockPerRow, mParams.mNumRow};
    dim3 block1 = {mParams.mNumThreadPerBlock, 1};
    int smemSize1 = mParams.mNumThreadPerBlock * mParams.mNumK;
    dim3 grid2 = {mParams.mNumBlockFinal};
    dim3 block2 = {mParams.mNumThreadFinal};
    int smemSize2 = mParams.mNumBlockPerRow * mParams.mNumK;
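    // smemSize1 / smemSize2 count candidate slots, not bytes; the launches below multiply
    // them by the per-candidate size (sizeof(value type) + sizeof(index type)) to obtain
    // the dynamic shared-memory size in bytes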

    if (inputs[0]->getType().code == halide_type_int && inputs[0]->getType().bits == 32) {
        TopKAllRows<int, int><<<grid1, block1, smemSize1 * (sizeof(int) + sizeof(int))>>>(static_cast<const int *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<int *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinInt, mParams.mDescendFlag);
        checkKernelErrors;
        GetResultAllRows<int, int><<<grid2, block2, smemSize2 * (sizeof(int) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<int *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<int *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag);
        checkKernelErrors;
    } else if (static_cast<CUDABackend*>(backend())->useFp16()) {
        TopKAllRows<int, half><<<grid1, block1, smemSize1 * (sizeof(half) + sizeof(int))>>>(static_cast<const half *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<half *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinHalf, mParams.mDescendFlag);
        checkKernelErrors;
        GetResultAllRows<int, half><<<grid2, block2, smemSize2 * (sizeof(half) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<half *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<half *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag);
        checkKernelErrors;
    } else {
        TopKAllRows<int, float><<<grid1, block1, smemSize1 * (sizeof(float) + sizeof(int))>>>(static_cast<const float *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<float *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinFloat, mParams.mDescendFlag);
        checkKernelErrors;
        GetResultAllRows<int, float><<<grid2, block2, smemSize2 * (sizeof(float) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<float *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<float *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag);
        checkKernelErrors;
    }

    return NO_ERROR;
}



class TopKV2Creator : public CUDABackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        return new TopKV2Execution(op, backend);
    }
};


static CUDACreatorRegister<TopKV2Creator> __init(OpType_TopKV2);


} // namespace CUDA
} // namespace MNN