// Mirror of https://github.com/alibaba/MNN.git (C++, 139 lines, 5.0 KiB)
#include "opencl_source_map.hpp"
|
|
namespace MNN {
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C program text for the buffer-based ArgMax / ArgMin kernels, stored
// as one string literal built from adjacent pieces.
//
// Macros the program text reads but does not define itself (presumably
// supplied as build options by the host code -- confirm against the caller):
//   FLOAT, FLOAT4       element type and its 4-wide vector type
//   MNN_SUPPORT_FP16    enables the cl_khr_fp16 extension pragma
//   ARGMAX              defined -> search for the maximum, otherwise minimum
//   ARGMAX_LOCAL_SIZE   when >= 4 the work-group reduction path is compiled,
//                       otherwise each work-item scans `dim` serially
//
// Kernels:
//   argmax_buf     one result per work-item (y, z)
//   argmax_v4_buf  four consecutive results per work-item via vload4/vstore4
//
// Both kernels read element i of a reduction at input[z*dim*inside + y + i*inside]
// and write the winning index to output[z*inside + y]. The `outside` argument
// is accepted but never referenced by either kernel body.
//
// Tie-breaking: the serial path keeps the first index at which the extreme
// value occurs (strict comparison). The work-group reduction uses
// ARGMAX_MERGE / ARGMIN_MERGE, which on equal values prefer the smaller
// index, so both paths return the same index for inputs with repeated
// extreme values.
//
// NOTE(review): DEAL_NON_UNIFORM_DIM3 returns before the barrier() calls in
// the reduction path; this is only safe if every work-item of a work-group
// takes the same branch -- confirm against the host launch configuration.
// NOTE(review): the halving loop `i=ARGMAX_LOCAL_SIZE/2; i>0; i/=2` covers all
// lanes only when ARGMAX_LOCAL_SIZE is a power of two -- confirm.
// NOTE(review): in argmax_v4_buf, y is already multiplied by 4 when it is
// range-checked against global_size_dim1, and vload4/vstore4 touch y..y+3;
// confirm the host passes a bound in element units and only selects this
// kernel when `inside` is a multiple of 4.
const char* argmax_buf =
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"#define ARGMAX_SELECT(A, B, C, D) "" if(A.x < B.x){ A.x = B.x; C.x = D; } "" if(A.y < B.y){ A.y = B.y; C.y = D; } "" if(A.z < B.z){ A.z = B.z; C.z = D; } "" if(A.w<B.w){ A.w=B.w; C.w=D; } \n"
"#define ARGMIN_SELECT(A, B, C, D) "" if(A.x > B.x){ A.x = B.x; C.x = D; } "" if(A.y > B.y){ A.y = B.y; C.y = D; } "" if(A.z > B.z){ A.z = B.z; C.z = D; } "" if(A.w>B.w){ A.w=B.w; C.w=D; } \n"
// Merge partial result (B value, D index) into (A value, C index). On equal
// values the smaller index wins so the reduction reports the first occurrence.
"#define ARGMAX_MERGE(A, B, C, D) if((A) < (B) || ((A) == (B) && (D) < (C))){ (A) = (B); (C) = (D); }\n"
"#define ARGMIN_MERGE(A, B, C, D) if((A) > (B) || ((A) == (B) && (D) < (C))){ (A) = (B); (C) = (D); }\n"
// ---- scalar kernel ---------------------------------------------------------
"__kernel void argmax_buf(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT* input,\n"
" __global int* output,\n"
" __private const int inside,\n"
" __private const int outside,\n"
" __private const int dim){\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1); // inside\n"
" const int z=get_global_id(2); // outside\n"
" \n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" int index=0;\n"
"#ifdef ARGMAX\n"
" FLOAT maxValue=(FLOAT)-FLT_MAX;\n"
"#else\n"
"FLOAT maxValue=(FLOAT)FLT_MAX;\n"
"#endif\n"
" const int offset=z*dim*inside+y;\n"
"#if ARGMAX_LOCAL_SIZE >= 4\n"
" int lid=get_local_id(0);\n"
" FLOAT local reduce[ARGMAX_LOCAL_SIZE];\n"
" int local index_reduce[ARGMAX_LOCAL_SIZE];\n"
" \n"
// Each lane scans indices lid, lid+ARGMAX_LOCAL_SIZE, ... and keeps its own
// first extreme value.
" for (int i=lid; i<dim; i+=ARGMAX_LOCAL_SIZE) {\n"
" FLOAT value=input[offset+i*inside];\n"
"#ifdef ARGMAX\n"
" if(maxValue<value){ maxValue=value; index=i; }\n"
"#else\n"
" if(maxValue>value){ maxValue=value; index=i; }\n"
"#endif\n"
" }\n"
" reduce[lid]=maxValue;\n"
" index_reduce[lid]=index;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
// Halving reduction over the per-lane partial results held in local memory.
" for(int i=ARGMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i){\n"
"#ifdef ARGMAX\n"
" ARGMAX_MERGE(reduce[lid],reduce[lid+i],index_reduce[lid],index_reduce[lid+i]);\n"
"#else\n"
" ARGMIN_MERGE(reduce[lid],reduce[lid+i],index_reduce[lid],index_reduce[lid+i]);\n"
"#endif\n"
" }\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" if(lid == 0){\n"
" output[z*inside+y]=index_reduce[0];\n"
" }\n"
"#else\n"
// Serial path: one work-item walks the whole reduction axis.
" for(int i=0; i<dim; ++i){\n"
" FLOAT value=input[+offset+i*inside];\n"
"#ifdef ARGMAX\n"
" if(maxValue<value){ maxValue=value; index=i; }\n"
"#else\n"
" if(maxValue>value){ maxValue=value; index=i; }\n"
"#endif\n"
" }\n"
" output[z*inside+y]=index;\n"
"#endif\n"
"}\n"
// ---- 4-wide kernel ---------------------------------------------------------
"__kernel void argmax_v4_buf(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT* input,\n"
" __global int* output,\n"
" __private const int inside,\n"
" __private const int outside,\n"
" __private const int dim){\n"
" const int x=get_global_id(0);\n"
" const int y=get_global_id(1) << 2; // inside\n"
" const int z=get_global_id(2); // outside\n"
" \n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" int4 index=0;\n"
"#ifdef ARGMAX\n"
" FLOAT4 maxValue=(FLOAT4)-FLT_MAX;\n"
"#else\n"
" FLOAT4 maxValue=(FLOAT4)FLT_MAX;\n"
"#endif\n"
" const int offset=z*dim*inside+y;\n"
"#if ARGMAX_LOCAL_SIZE >= 4\n"
" int lid=get_local_id(0);\n"
" FLOAT4 local reduce[ARGMAX_LOCAL_SIZE];\n"
" int4 local index_reduce[ARGMAX_LOCAL_SIZE];\n"
" \n"
" for (int i=lid; i<dim; i+=ARGMAX_LOCAL_SIZE) {\n"
" FLOAT4 value=vload4(0,input+offset+i*inside);\n"
"#ifdef ARGMAX\n"
" ARGMAX_SELECT(maxValue,value,index,i);\n"
"#else\n"
" ARGMIN_SELECT(maxValue,value,index,i);\n"
"#endif\n"
" }\n"
" reduce[lid]=maxValue;\n"
" index_reduce[lid]=index;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
// Same halving reduction as the scalar kernel, applied per vector component.
" for(int i=ARGMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i){\n"
"#ifdef ARGMAX\n"
" ARGMAX_MERGE(reduce[lid].x,reduce[lid+i].x,index_reduce[lid].x,index_reduce[lid+i].x);\n"
" ARGMAX_MERGE(reduce[lid].y,reduce[lid+i].y,index_reduce[lid].y,index_reduce[lid+i].y);\n"
" ARGMAX_MERGE(reduce[lid].z,reduce[lid+i].z,index_reduce[lid].z,index_reduce[lid+i].z);\n"
" ARGMAX_MERGE(reduce[lid].w,reduce[lid+i].w,index_reduce[lid].w,index_reduce[lid+i].w);\n"
"#else\n"
" ARGMIN_MERGE(reduce[lid].x,reduce[lid+i].x,index_reduce[lid].x,index_reduce[lid+i].x);\n"
" ARGMIN_MERGE(reduce[lid].y,reduce[lid+i].y,index_reduce[lid].y,index_reduce[lid+i].y);\n"
" ARGMIN_MERGE(reduce[lid].z,reduce[lid+i].z,index_reduce[lid].z,index_reduce[lid+i].z);\n"
" ARGMIN_MERGE(reduce[lid].w,reduce[lid+i].w,index_reduce[lid].w,index_reduce[lid+i].w);\n"
"#endif\n"
" }\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" if(lid == 0){\n"
" vstore4(index_reduce[0],0,output+z*inside+y);\n"
" }\n"
"#else\n"
" for(int i=0; i<dim; ++i){\n"
" FLOAT4 value=vload4(0,input+offset+i*inside);\n"
"#ifdef ARGMAX\n"
" ARGMAX_SELECT(maxValue,value,index,i);\n"
"#else\n"
" ARGMIN_SELECT(maxValue,value,index,i);\n"
"#endif\n"
" }\n"
" vstore4(index,0,output+z*inside+y);\n"
"#endif\n"
"}\n"
;
#endif
}