2025-04-28 11:38:44 +08:00
|
|
|
#include "opencl_source_map.hpp"
|
|
|
|
namespace MNN {

// Text of an OpenCL C program holding three softmax kernels. The value is a
// single C string formed by concatenating the adjacent string literals below;
// the // comments placed between literals are removed before concatenation
// and are NOT part of the kernel text.
//
//   softmax_channel - for image row bh, softmax over the slices read at
//                     x = w + i*shape.w, i in [0, shape.y). The four lanes of
//                     every FLOAT4 are reduced together into one max/sum, and
//                     remain_channels selects which lanes of the last slice
//                     (i = shape.y-1) take part.
//   softmax_height  - for image column wc, softmax over rows b*shape.z + i,
//                     i in [0, shape.z), computed independently per lane.
//   softmax_width   - for image row bh, softmax over columns c*shape.w + i,
//                     i in [0, shape.w), computed independently per lane.
//
// All three kernels share one argument list; only softmax_channel reads
// remain_channels.
//
// Identifiers used but not defined in this text: FLOAT4, RI_F, WI_F and
// SOFTMAX_LOCAL_SIZE. The host must supply them when building the program
// (build options or a prepended header). NOTE(review): presumably FLOAT4 is
// float4/half4 and RI_F/WI_F wrap the image read/write built-ins; verify
// against the host-side build code.
//
// NOTE(review): this looks like the output of a .cl -> C++ embedding step. If
// so, change the original .cl kernel and regenerate instead of hand-editing
// this string. Every byte is kernel text, including the "" splits (they keep
// each macro on a single kernel line) and the whitespace-only lines.
const char* softmax =
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// EXP is defined but never referenced below; the kernels call exp() directly.
"#define EXP exp\n"
// GLOBAL_SIZE_3_DIMS prepends three global-size arguments to every kernel;
// DEAL_NON_UNIFORM_DIM3 returns early for work-items outside those sizes.
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
// Sampler for integer (non-normalized) pixel coordinates, nearest filtering.
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
// ---- softmax_channel ----
// Global ids: x (dim 0) is only bounds-checked; w (dim 1) and bh (dim 2) pick
// the image position. In the SOFTMAX_LOCAL_SIZE >= 4 path, get_local_id(0)
// splits the slice loop across the work-group.
"__kernel void softmax_channel(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,__write_only image2d_t output,\n"
" __private const int remain_channels,__private const int4 shape // NCHW\n"
" ) {\n"
" const int x=get_global_id(0);\n"
" const int w=get_global_id(1);\n"
" const int bh=get_global_id(2);\n"
// NOTE(review): this early return can happen before the barrier() calls of the
// cooperative path. barrier() requires every work-item of a work-group to reach
// it, so this assumes no work-group is ever partially out of range (e.g. global
// size in dim 0 equal to SOFTMAX_LOCAL_SIZE). Confirm against the host launch.
" DEAL_NON_UNIFORM_DIM3(x,w,bh);\n"
// Cooperative path: each work-item visits slices lid, lid+SOFTMAX_LOCAL_SIZE,
// ..., stores its partial result in local memory, and the group combines the
// partials with a halving tree reduction. The halving loop only reaches every
// slot when SOFTMAX_LOCAL_SIZE is a power of two, and lid must stay below
// SOFTMAX_LOCAL_SIZE (assumes local size in dim 0 == SOFTMAX_LOCAL_SIZE).
"#if SOFTMAX_LOCAL_SIZE >= 4\n"
" int lid=get_local_id(0);\n"
// Max and sum use separate local buffers: after the max reduction every
// work-item reads max_mnn[0]; with one shared buffer, work-item 0 could
// overwrite slot 0 with its partial sum while others are still reading the max.
" FLOAT4 local sum_mnn[SOFTMAX_LOCAL_SIZE];\n"
" FLOAT4 local max_mnn[SOFTMAX_LOCAL_SIZE];\n"
// Per-lane max over every slice except the last one (slice shape.y-1).
" FLOAT4 maxValue=(FLOAT4)-FLT_MAX;\n"
" for (int i=lid; i<shape.y-1; i+=SOFTMAX_LOCAL_SIZE) {\n"
" maxValue=fmax(maxValue,RI_F(input,SAMPLER,(int2)(w+i*shape.w,bh)));\n"
" }\n"
" max_mnn[lid]=maxValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" max_mnn[lid]=fmax(max_mnn[lid],max_mnn[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
// Every work-item now holds the same per-lane max; fold the lanes into .x.
" maxValue=max_mnn[0];\n"
" maxValue.x=fmax(maxValue.x,maxValue.y);\n"
" maxValue.x=fmax(maxValue.x,maxValue.z);\n"
" maxValue.x=fmax(maxValue.x,maxValue.w);\n"
// Last slice: include only the lanes selected by remain_channels
// (0 -> x,y,z,w; 1 -> x,y,z; 2 -> x,y; 3 -> x). remain_channels therefore
// behaves like the number of padding lanes in that slice.
" FLOAT4 input_data=RI_F(input,SAMPLER,(int2)(w+(shape.y-1)*shape.w ,bh));\n"
" if (remain_channels == 0) {\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" maxValue.x=fmax(maxValue.x,input_data.y);\n"
" maxValue.x=fmax(maxValue.x,input_data.z);\n"
" maxValue.x=fmax(maxValue.x,input_data.w);\n"
" } else if (remain_channels == 1) {\n"
" maxValue.x=fmax(maxValue.x,input_data.z);\n"
" maxValue.x=fmax(maxValue.x,input_data.y);\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" } else if (remain_channels == 2) {\n"
" maxValue.x=fmax(maxValue.x,input_data.y);\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" } else if (remain_channels == 3) {\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" }\n"
// Sum of exp(v - max) over all but the last slice, reduced like the max.
" FLOAT4 sumValue=(FLOAT4)0;\n"
" for (int i=lid; i<shape.y-1; i+=SOFTMAX_LOCAL_SIZE) {\n"
" sumValue += exp(RI_F(input,SAMPLER,(int2)(w+i*shape.w,bh))-(FLOAT4)maxValue.x);\n"
" }\n"
" sum_mnn[lid]=sumValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum_mnn[lid]=sum_mnn[lid]+sum_mnn[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
// Fold the lanes into .x, then add the selected lanes of the last slice
// (same lane selection as the max pass above).
" sumValue=sum_mnn[0];\n"
" sumValue.x=sumValue.x+sumValue.y+sumValue.z+sumValue.w;\n"
" \n"
" \n"
" input_data -= maxValue.x;\n"
" if (remain_channels == 0) {\n"
" sumValue.x += exp(input_data.w);\n"
" sumValue.x += exp(input_data.z);\n"
" sumValue.x += exp(input_data.y);\n"
" sumValue.x += exp(input_data.x);\n"
" } else if (remain_channels == 1) {\n"
" sumValue.x += exp(input_data.z);\n"
" sumValue.x += exp(input_data.y);\n"
" sumValue.x += exp(input_data.x);\n"
" } else if (remain_channels == 2) {\n"
" sumValue.x += exp(input_data.y);\n"
" sumValue.x += exp(input_data.x);\n"
" } else if (remain_channels == 3) {\n"
" sumValue.x += exp(input_data.x);\n"
" }\n"
// Write every slice, including the padding lanes of the last one.
" for(int i=lid; i<shape.y; i+=SOFTMAX_LOCAL_SIZE){\n"
" FLOAT4 value=exp(RI_F(input,SAMPLER,(int2)(w+i*shape.w,bh))-maxValue.x)/sumValue.x;\n"
" WI_F(output,(int2)(w+i*shape.w,bh),value);\n"
" }\n"
// Serial path: one work-item walks all slices itself with the same steps as
// above, without local memory. NOTE(review): x is not used to split the work,
// so work-items that differ only in x would write identical results;
// presumably the host launches dim 0 with size 1 here. Verify.
"#else\n"
" FLOAT4 maxValue=(FLOAT4)-FLT_MAX;\n"
" for (int i=0; i<shape.y-1; i++) {\n"
" maxValue=fmax(maxValue,RI_F(input,SAMPLER,(int2)(w+i*shape.w,bh)));\n"
" }\n"
" \n"
" maxValue.x=fmax(maxValue.x,maxValue.y);\n"
" maxValue.x=fmax(maxValue.x,maxValue.z);\n"
" maxValue.x=fmax(maxValue.x,maxValue.w);\n"
" FLOAT4 input_data=RI_F(input,SAMPLER,(int2)(w+(shape.y-1)*shape.w ,bh));\n"
" if (remain_channels == 0) {\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" maxValue.x=fmax(maxValue.x,input_data.y);\n"
" maxValue.x=fmax(maxValue.x,input_data.z);\n"
" maxValue.x=fmax(maxValue.x,input_data.w);\n"
" } else if (remain_channels == 1) {\n"
" maxValue.x=fmax(maxValue.x,input_data.z);\n"
" maxValue.x=fmax(maxValue.x,input_data.y);\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" } else if (remain_channels == 2) {\n"
" maxValue.x=fmax(maxValue.x,input_data.y);\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" } else if (remain_channels == 3) {\n"
" maxValue.x=fmax(maxValue.x,input_data.x);\n"
" }\n"
" FLOAT4 sumValue=(FLOAT4)0;\n"
" for (int i=0; i<shape.y-1; i++) {\n"
" sumValue += exp(RI_F(input,SAMPLER,(int2)(w+i*shape.w,bh))-(FLOAT4)maxValue.x);\n"
" }\n"
" sumValue.x=sumValue.x+sumValue.y+sumValue.z+sumValue.w;\n"
" input_data -= maxValue.x;\n"
" if (remain_channels == 0) {\n"
" sumValue.x += exp(input_data.w);\n"
" sumValue.x += exp(input_data.z);\n"
" sumValue.x += exp(input_data.y);\n"
" sumValue.x += exp(input_data.x);\n"
" } else if (remain_channels == 1) {\n"
" sumValue.x += exp(input_data.z);\n"
" sumValue.x += exp(input_data.y);\n"
" sumValue.x += exp(input_data.x);\n"
" } else if (remain_channels == 2) {\n"
" sumValue.x += exp(input_data.y);\n"
" sumValue.x += exp(input_data.x);\n"
" } else if (remain_channels == 3) {\n"
" sumValue.x += exp(input_data.x);\n"
" }\n"
" for(int i=0; i<shape.y; i++){\n"
" FLOAT4 value=exp(RI_F(input,SAMPLER,(int2)(w+i*shape.w,bh))-maxValue.x)/sumValue.x;\n"
" WI_F(output,(int2)(w+i*shape.w,bh),value);\n"
" }\n"
"#endif\n"
"}\n"
// ---- softmax_height ----
// Global ids: x (dim 0) is only bounds-checked, wc (dim 1) is the image column,
// b (dim 2) selects the run of shape.z rows starting at row b*shape.z. Each
// FLOAT4 lane is normalized on its own; remain_channels is not read.
// NOTE(review): as in softmax_channel, the early return below precedes
// barrier() calls in the cooperative path; assumes no work-group is partially
// out of range.
"__kernel void softmax_height(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,__write_only image2d_t output,\n"
" __private const int remain_channels,__private const int4 shape // NCHW\n"
" ) {\n"
" const int x=get_global_id(0);\n"
" const int wc=get_global_id(1);\n"
" const int b=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(x,wc,b);\n"
// Cooperative path: rows are strided by SOFTMAX_LOCAL_SIZE across the
// work-group and combined with a halving tree reduction (needs a power-of-two
// SOFTMAX_LOCAL_SIZE); separate max/sum buffers for the reason given in
// softmax_channel.
"#if SOFTMAX_LOCAL_SIZE >= 4\n"
" int lid=get_local_id(0);\n"
" FLOAT4 local sum_mnn[SOFTMAX_LOCAL_SIZE];\n"
" FLOAT4 local max_mnn[SOFTMAX_LOCAL_SIZE];\n"
" /*Compute Max */\n"
" FLOAT4 maxValue=(FLOAT4)(-FLT_MAX);\n"
" for (int i=lid; i<shape.z; i+=SOFTMAX_LOCAL_SIZE) {\n"
" maxValue=fmax(maxValue,RI_F(input,SAMPLER,(int2)(wc,b*shape.z+i)));\n"
" }\n"
" max_mnn[lid]=maxValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" max_mnn[lid]=fmax(max_mnn[lid],max_mnn[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" maxValue=max_mnn[0];\n"
" \n"
" /*Compute Exp Sum*/\n"
" FLOAT4 sumValue=(FLOAT4)0;\n"
" for (int i=lid; i<shape.z; i+=SOFTMAX_LOCAL_SIZE) {\n"
" sumValue += exp(RI_F(input,SAMPLER,(int2)(wc,b*shape.z+i))-maxValue);\n"
" }\n"
" sum_mnn[lid]=sumValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum_mnn[lid]=sum_mnn[lid]+sum_mnn[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" sumValue=sum_mnn[0];\n"
" \n"
" /*Compute Result */\n"
" for (int i=lid; i<shape.z; i+=SOFTMAX_LOCAL_SIZE) {\n"
" FLOAT4 value=exp(RI_F(input,SAMPLER,(int2)(wc,b*shape.z+i))-maxValue)/sumValue;\n"
" WI_F(output,(int2)(wc,b*shape.z+i),value);\n"
" }\n"
// Serial path: one work-item handles all shape.z rows of its column.
"#else\n"
" /*Compute Max */\n"
" FLOAT4 maxValue=(FLOAT4)(-FLT_MAX);\n"
" for (int i=0; i<shape.z; i++) {\n"
" maxValue=fmax(maxValue,RI_F(input,SAMPLER,(int2)(wc,b*shape.z+i)));\n"
" }\n"
" \n"
" /*Compute Exp Sum*/\n"
" FLOAT4 sumValue=(FLOAT4)0;\n"
" for (int i=0; i<shape.z; i++) {\n"
" sumValue += exp(RI_F(input,SAMPLER,(int2)(wc,b*shape.z+i))-maxValue);\n"
" }\n"
" \n"
" /*Compute Result */\n"
" for (int i=0; i<shape.z; i++) {\n"
" FLOAT4 value=exp(RI_F(input,SAMPLER,(int2)(wc,b*shape.z+i))-maxValue)/sumValue;\n"
" WI_F(output,(int2)(wc,b*shape.z+i),value);\n"
" }\n"
"#endif\n"
"}\n"
// ---- softmax_width ----
// Global ids: x (dim 0) is only bounds-checked, c (dim 1) selects the run of
// shape.w columns starting at column c*shape.w, bh (dim 2) is the image row.
// Each FLOAT4 lane is normalized on its own; remain_channels is not read.
// NOTE(review): as in softmax_channel, the early return below precedes
// barrier() calls in the cooperative path; assumes no work-group is partially
// out of range.
"__kernel void softmax_width(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,__write_only image2d_t output,\n"
" __private const int remain_channels,__private const int4 shape // NCHW\n"
" ) {\n"
" const int x=get_global_id(0);\n"
" const int c=get_global_id(1);\n"
" const int bh=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(x,c,bh);\n"
// Cooperative path: columns are strided by SOFTMAX_LOCAL_SIZE across the
// work-group and combined with a halving tree reduction (needs a power-of-two
// SOFTMAX_LOCAL_SIZE); separate max/sum buffers for the reason given in
// softmax_channel.
"#if SOFTMAX_LOCAL_SIZE >= 4\n"
" int lid=get_local_id(0);\n"
" FLOAT4 local sum_mnn[SOFTMAX_LOCAL_SIZE];\n"
" FLOAT4 local max_mnn[SOFTMAX_LOCAL_SIZE];\n"
" \n"
" /*Compute Max */\n"
" FLOAT4 maxValue=(FLOAT4)(-FLT_MAX);\n"
" for (int i=lid; i<shape.w; i+=SOFTMAX_LOCAL_SIZE) {\n"
" maxValue=fmax(maxValue,RI_F(input,SAMPLER,(int2)(c*shape.w+i,bh)));\n"
" }\n"
" max_mnn[lid]=maxValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" max_mnn[lid]=fmax(max_mnn[lid],max_mnn[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" maxValue=max_mnn[0];\n"
" \n"
" /*Compute Exp Sum*/\n"
" FLOAT4 sumValue=(FLOAT4)0;\n"
" for (int i=lid; i<shape.w; i+=SOFTMAX_LOCAL_SIZE) {\n"
" sumValue += exp(RI_F(input,SAMPLER,(int2)(c*shape.w+i,bh))-maxValue);\n"
" }\n"
" sum_mnn[lid]=sumValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum_mnn[lid]=sum_mnn[lid]+sum_mnn[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" sumValue=sum_mnn[0];\n"
" \n"
" /*Compute Result */\n"
" for (int i=lid; i<shape.w; i+=SOFTMAX_LOCAL_SIZE) {\n"
" FLOAT4 value=exp(RI_F(input,SAMPLER,(int2)(c*shape.w+i,bh))-maxValue)/sumValue;\n"
" WI_F(output,(int2)(c*shape.w+i,bh),value);\n"
" }\n"
// Serial path: one work-item handles all shape.w columns of its row.
"#else\n"
" /*Compute Max */\n"
" FLOAT4 maxValue=(FLOAT4)(-FLT_MAX);\n"
" for (int i=0; i<shape.w; i++) {\n"
" maxValue=fmax(maxValue,RI_F(input,SAMPLER,(int2)(c*shape.w+i,bh)));\n"
" }\n"
" /*Compute Exp Sum*/\n"
" FLOAT4 sumValue=(FLOAT4)0;\n"
" for (int i=0; i<shape.w; i++) {\n"
" sumValue += exp(RI_F(input,SAMPLER,(int2)(c*shape.w+i,bh))-maxValue);\n"
" }\n"
" \n"
" /*Compute Result */\n"
" for (int i=0; i<shape.w; i++) {\n"
" FLOAT4 value=exp(RI_F(input,SAMPLER,(int2)(c*shape.w+i,bh))-maxValue)/sumValue;\n"
" WI_F(output,(int2)(c*shape.w+i,bh),value);\n"
" }\n"
"#endif\n"
"}\n"
;

} // namespace MNN
|