2025-04-28 11:38:44 +08:00
|
|
|
#include "opencl_source_map.hpp"
|
|
|
|
namespace MNN {
|
|
|
|
// OpenCL C program text for converting between a shared uchar buffer and an
// OpenCL tensor (image2d or plain buffer). The text is one string literal
// assembled from adjacent pieces; every piece ends with "\n" so the OpenCL
// preprocessor sees one directive/statement per line.
//
// Kernels contained in the text:
//   gl_to_cl  (compiled only when SHARED_TO_CL is defined)
//       Each work-item reads four consecutive uchar4 texels from input_ptr,
//       converts them with CONVERT_OUTPUT4 and writes them either to an image
//       (USE_IMAGE) or to a buffer laid out per OUTPUT_FORMAT.
//   cl_to_gl  (compiled only when CL_TO_SHARED is defined)
//       The reverse direction: reads four texels from an image (USE_IMAGE) or
//       from a buffer laid out per INPUT_FORMAT, converts them with
//       convert_uchar4 and stores them to output_ptr.
//
// The text references but does not define OUTPUT_TYPE, OUTPUT_TYPE4,
// CONVERT_OUTPUT4, WI_DATA, INPUT_TYPE, INPUT_TYPE4, RI_DATA, OUTPUT_FORMAT,
// INPUT_FORMAT and USE_IMAGE; they have to come from outside this string
// (presumably program build options -- confirm against the host code).
//
// NOTE(review): this looks like generated output of a .cl file; if so, the fix
// marked "FIX" below should also be applied to that original .cl source.
const char* glmem_convert =
    // ---- shared preamble -------------------------------------------------
    "#ifdef MNN_SUPPORT_FP16\n"
    "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
    "#endif\n"
    "#define GLOBAL_SIZE_3_DIMS __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
    // Early-out for work-items outside the logical global size.
    "#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
    "#define MNN_DATA_FORMAT_NCHW 0\n"
    "#define MNN_DATA_FORMAT_NHWC 1\n"
    "#define MNN_DATA_FORMAT_NC4HW4 2\n"
    "#define MNN_DATA_FORMAT_C4NHW4 3\n"
    // Two-level paste so OUTPUT_TYPE is expanded before ## is applied.
    "#define __CAT(x,y) x##y\n"
    "#define CAT(x,y) __CAT(x,y)\n"
    "#define OUTPUT_TYPE2 CAT(OUTPUT_TYPE,2)\n"
    "#define OUTPUT_TYPE3 CAT(OUTPUT_TYPE,3)\n"
    "__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"

    // ---- kernel: gl_to_cl --------------------------------------------------
    "#ifdef SHARED_TO_CL\n"
    "__kernel void gl_to_cl(GLOBAL_SIZE_3_DIMS\n"
    " __global uchar *input_ptr,\n"
    " #ifdef USE_IMAGE\n"
    " __write_only image2d_t output_ptr,\n"
    " #else\n"
    " __global OUTPUT_TYPE *output_ptr,\n"
    " #endif\n"
    " __private const int4 shape // N C H W\n"
    ") {\n"
    " int wblock=get_global_id(0);\n"
    " int cblock=get_global_id(1);\n"
    " int nh=get_global_id(2);\n"
    " DEAL_NON_UNIFORM_DIM3(wblock,cblock,nh);\n"
    // One work-item covers 4 pixels along W and 4 channels along C.
    " const int w=wblock << 2;\n"
    " const int h=nh % shape.z;\n"
    " const int c=cblock << 2;\n"
    " const int n=nh/shape.z;\n"
    " \n"
    // NOTE(review): idx uses c (= cblock*4) while the inline comment says
    // "c/4*w"; the two agree only when cblock == 0 -- confirm intent.
    " int idx=c*shape.w+w; // c/4*w\n"
    " int idy=nh; // n*h\n"
    " const int offset=idy*shape.w*4;\n"
    // NOTE(review): all four texels are loaded before the w+1/w+2/w+3 bounds
    // checks below, so edge work-items read past the last pixel of the row.
    " OUTPUT_TYPE4 in0=CONVERT_OUTPUT4(vload4(idx,input_ptr+offset));\n"
    " OUTPUT_TYPE4 in1=CONVERT_OUTPUT4(vload4(idx+1,input_ptr+offset));\n"
    " OUTPUT_TYPE4 in2=CONVERT_OUTPUT4(vload4(idx+2,input_ptr+offset));\n"
    " OUTPUT_TYPE4 in3=CONVERT_OUTPUT4(vload4(idx+3,input_ptr+offset));\n"
    "#ifdef USE_IMAGE\n"
    " WI_DATA(output_ptr,(int2)(idx,idy),in0);\n"
    " if(w+1 >= shape.w) return;\n"
    " WI_DATA(output_ptr,(int2)(idx+1,idy),in1);\n"
    " if(w+2 >= shape.w) return;\n"
    " WI_DATA(output_ptr,(int2)(idx+2,idy),in2);\n"
    " if(w+3 >= shape.w) return;\n"
    " WI_DATA(output_ptr,(int2)(idx+3,idy),in3);\n"
    "#else\n"
    // NCHW: transpose the 4x4 (pixel x channel) tile so that each store writes
    // one channel plane; `remain` limits the store width at the right edge.
    " #if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n"
    " int output_offset=((n*shape.y+c)*shape.z+h)*shape.w+w;\n"
    " int stride=shape.z*shape.w;\n"
    " int remain=shape.w-w;\n"
    " if(remain >= 4){\n"
    " vstore4((OUTPUT_TYPE4)(in0.x,in1.x,in2.x,in3.x),0,output_ptr+output_offset);\n"
    " if(c+1 >= shape.y) return;\n"
    " vstore4((OUTPUT_TYPE4)(in0.y,in1.y,in2.y,in3.y),0,output_ptr+output_offset+stride);\n"
    " if(c+2 >= shape.y) return;\n"
    " vstore4((OUTPUT_TYPE4)(in0.z,in1.z,in2.z,in3.z),0,output_ptr+output_offset+stride+stride);\n"
    " if(c+3 >= shape.y) return;\n"
    " vstore4((OUTPUT_TYPE4)(in0.w,in1.w,in2.w,in3.w),0,output_ptr+output_offset+stride+stride+stride);\n"
    " } else if(remain == 3){\n"
    " vstore3((OUTPUT_TYPE3)(in0.x,in1.x,in2.x),0,output_ptr+output_offset);\n"
    " if(c+1 >= shape.y) return;\n"
    " vstore3((OUTPUT_TYPE3)(in0.y,in1.y,in2.y),0,output_ptr+output_offset+stride);\n"
    " if(c+2 >= shape.y) return;\n"
    " vstore3((OUTPUT_TYPE3)(in0.z,in1.z,in2.z),0,output_ptr+output_offset+stride+stride);\n"
    " if(c+3 >= shape.y) return;\n"
    " vstore3((OUTPUT_TYPE3)(in0.w,in1.w,in2.w),0,output_ptr+output_offset+stride+stride+stride);\n"
    " } else if(remain == 2){\n"
    " vstore2((OUTPUT_TYPE2)(in0.x,in1.x),0,output_ptr+output_offset);\n"
    " if(c+1 >= shape.y) return;\n"
    " vstore2((OUTPUT_TYPE2)(in0.y,in1.y),0,output_ptr+output_offset+stride);\n"
    " if(c+2 >= shape.y) return;\n"
    " vstore2((OUTPUT_TYPE2)(in0.z,in1.z),0,output_ptr+output_offset+stride+stride);\n"
    " if(c+3 >= shape.y) return;\n"
    " vstore2((OUTPUT_TYPE2)(in0.w,in1.w),0,output_ptr+output_offset+stride+stride+stride);\n"
    " }else if(remain == 1){\n"
    " output_ptr[output_offset]=in0.x;\n"
    " if(c+1 >= shape.y) return;\n"
    " output_ptr[output_offset+stride]=in0.y;\n"
    " if(c+2 >= shape.y) return;\n"
    " output_ptr[output_offset+stride+stride]=in0.z;\n"
    " if(c+3 >= shape.y) return;\n"
    " output_ptr[output_offset+stride+stride+stride]=in0.w;\n"
    " }\n"
    // NHWC: channels are contiguous, so each pixel is stored as-is; here
    // `remain` is the number of valid channels left in this channel block.
    " #elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n"
    " int output_offset=((n*shape.z+h)*shape.w+w)*shape.y+c;\n"
    " int remain=shape.y-c;\n"
    " if(remain >= 4){\n"
    " vstore4(CONVERT_OUTPUT4(in0),0,output_ptr+output_offset);\n"
    " if(w+1 >= shape.w) return;\n"
    " vstore4(CONVERT_OUTPUT4(in1),0,output_ptr+output_offset+shape.y);\n"
    " if(w+2 >= shape.w) return;\n"
    " vstore4(CONVERT_OUTPUT4(in2),0,output_ptr+output_offset+shape.y+shape.y);\n"
    " if(w+3 >= shape.w) return;\n"
    " vstore4(CONVERT_OUTPUT4(in3),0,output_ptr+output_offset+shape.y+shape.y+shape.y);\n"
    " } else if(remain == 3){\n"
    " vstore3((OUTPUT_TYPE3)(in0.x,in0.y,in0.z),0,output_ptr+output_offset);\n"
    " if(w+1 >= shape.w) return;\n"
    " vstore3((OUTPUT_TYPE3)(in1.x,in1.y,in1.z),0,output_ptr+output_offset+shape.y);\n"
    " if(w+2 >= shape.w) return;\n"
    " vstore3((OUTPUT_TYPE3)(in2.x,in2.y,in2.z),0,output_ptr+output_offset+shape.y+shape.y);\n"
    " if(w+3 >= shape.w) return;\n"
    " vstore3((OUTPUT_TYPE3)(in3.x,in3.y,in3.z),0,output_ptr+output_offset+shape.y+shape.y+shape.y);\n"
    " } else if(remain == 2){\n"
    " vstore2((OUTPUT_TYPE2)(in0.x,in0.y),0,output_ptr+output_offset);\n"
    " if(w+1 >= shape.w) return;\n"
    " vstore2((OUTPUT_TYPE2)(in1.x,in1.y),0,output_ptr+output_offset+shape.y);\n"
    " if(w+2 >= shape.w) return;\n"
    " vstore2((OUTPUT_TYPE2)(in2.x,in2.y),0,output_ptr+output_offset+shape.y+shape.y);\n"
    " if(w+3 >= shape.w) return;\n"
    " vstore2((OUTPUT_TYPE2)(in3.x,in3.y),0,output_ptr+output_offset+shape.y+shape.y+shape.y);\n"
    " }else if(remain == 1){\n"
    " output_ptr[output_offset]=in0.x;\n"
    " if(w+1 >= shape.w) return;\n"
    " output_ptr[output_offset+shape.y]=in1.x;\n"
    " if(w+2 >= shape.w) return;\n"
    // FIX: pixels w+2 and w+3 previously stored in1.x (copy/paste slip), which
    // duplicated pixel w+1; they must take their own texels in2 / in3, as in
    // the remain == 2/3/4 branches above.
    " output_ptr[output_offset+shape.y+shape.y]=in2.x;\n"
    " if(w+3 >= shape.w) return;\n"
    " output_ptr[output_offset+shape.y+shape.y+shape.y]=in3.x;\n"
    " }\n"
    // NC4HW4: four channels per pixel are packed together, so consecutive
    // pixels are 4 elements apart.
    " #elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n"
    " int output_offset=(((cblock*shape.x+n)*shape.z+h)*shape.w+w)*4;\n"
    " vstore4(in0,0,output_ptr+output_offset);\n"
    " if(w+1 >= shape.w) return;\n"
    " vstore4(in1,0,output_ptr+output_offset+4);\n"
    " if(w+2 >= shape.w) return;\n"
    " vstore4(in2,0,output_ptr+output_offset+8);\n"
    " if(w+3 >= shape.w) return;\n"
    " vstore4(in3,0,output_ptr+output_offset+12);\n"
    " #else\n"
    " //not support\n"
    " #endif\n"
    "#endif\n"
    "}\n"
    "#endif\n"

    // ---- kernel: cl_to_gl --------------------------------------------------
    "#ifdef CL_TO_SHARED\n"
    "__kernel void cl_to_gl(GLOBAL_SIZE_3_DIMS\n"
    " #ifdef USE_IMAGE\n"
    " __read_only image2d_t input_ptr,\n"
    " #else\n"
    " __global INPUT_TYPE *input_ptr,\n"
    " #endif\n"
    " __global uchar *output_ptr,\n"
    " __private const int4 shape // N C H W\n"
    ") {\n"
    " int wblock=get_global_id(0);\n"
    " int cblock=get_global_id(1);\n"
    " int nh=get_global_id(2);\n"
    " DEAL_NON_UNIFORM_DIM3(wblock,cblock,nh);\n"
    " const int w=wblock << 2;\n"
    " const int h=nh % shape.z;\n"
    " const int c=cblock << 2;\n"
    " const int n=nh/shape.z;\n"
    " \n"
    " int idx=c*shape.w+w; // c/4*w\n"
    " int idy=nh; // n*h\n"
    " INPUT_TYPE4 in0,in1,in2,in3;\n"
    "#ifdef USE_IMAGE\n"
    " in0=RI_DATA(input_ptr,SAMPLER,(int2)(idx,idy));\n"
    " in1=RI_DATA(input_ptr,SAMPLER,(int2)(idx+1,idy));\n"
    " in2=RI_DATA(input_ptr,SAMPLER,(int2)(idx+2,idy));\n"
    " in3=RI_DATA(input_ptr,SAMPLER,(int2)(idx+3,idy));\n"
    "#else\n"
    // NCHW: load 4 pixels from each of 4 channel planes, then transpose so
    // that inK holds the 4 channels of pixel w+K.
    // NOTE(review): these loads are not guarded against the W / C remainders.
    " #if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n"
    " int input_offset=((n*shape.y+c)*shape.z+h)*shape.w+w;\n"
    " int stride=shape.z*shape.w;\n"
    " INPUT_TYPE4 tmp0,tmp1,tmp2,tmp3;\n"
    " tmp0=vload4(0,input_ptr+input_offset);\n"
    " tmp1=vload4(0,input_ptr+input_offset+stride);\n"
    " tmp2=vload4(0,input_ptr+input_offset+stride+stride);\n"
    " tmp3=vload4(0,input_ptr+input_offset+stride+stride+stride);\n"
    " in0=(INPUT_TYPE4)(tmp0.x,tmp1.x,tmp2.x,tmp3.x);\n"
    " in1=(INPUT_TYPE4)(tmp0.y,tmp1.y,tmp2.y,tmp3.y);\n"
    " in2=(INPUT_TYPE4)(tmp0.z,tmp1.z,tmp2.z,tmp3.z);\n"
    " in3=(INPUT_TYPE4)(tmp0.w,tmp1.w,tmp2.w,tmp3.w);\n"
    " #elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n"
    " int input_offset=((n*shape.z+h)*shape.w+w)*shape.y+c;\n"
    " in0=vload4(0,input_ptr+input_offset);\n"
    " in1=vload4(0,input_ptr+input_offset+shape.y);\n"
    " in2=vload4(0,input_ptr+input_offset+shape.y+shape.y);\n"
    " in3=vload4(0,input_ptr+input_offset+shape.y+shape.y+shape.y);\n"
    " #elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n"
    " int input_offset=(((cblock*shape.x+n)*shape.z+h)*shape.w+w)*4;\n"
    " in0=vload4(0,input_ptr+input_offset);\n"
    " in1=vload4(0,input_ptr+input_offset+4);\n"
    " in2=vload4(0,input_ptr+input_offset+8);\n"
    " in3=vload4(0,input_ptr+input_offset+12);\n"
    " #else\n"
    " // not support\n"
    " #endif\n"
    "#endif\n"
    // Stores are guarded pixel by pixel so nothing is written past column W-1.
    " const int offset=idy*shape.w*4;\n"
    " vstore4(convert_uchar4(in0),idx,output_ptr+offset);\n"
    " if(w+1 >= shape.w) return;\n"
    " vstore4(convert_uchar4(in1),idx+1,output_ptr+offset);\n"
    " if(w+2 >= shape.w) return;\n"
    " vstore4(convert_uchar4(in2),idx+2,output_ptr+offset);\n"
    " if(w+3 >= shape.w) return;\n"
    " vstore4(convert_uchar4(in3),idx+3,output_ptr+offset);\n"
    "}\n"
    "#endif\n"
    ;
|
|
|
|
}
|