// MNN/source/backend/opencl/execution/cl/gather_buf_mnn_cl.cpp
// (51 lines, 1.5 KiB, C++)

#include "opencl_source_map.hpp"
namespace MNN {
#ifndef MNN_OPENCL_BUFFER_CLOSED
const char* gather_buf =
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__kernel void batch_gather_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global OUTPUT_TYPE* output,__global INPUT_TYPE* input,\n"
" #ifdef OFFSET_DST\n"
" __global int* offset_dst_ptr,\n"
" #endif\n"
" #ifdef OFFSET_SRC\n"
" __global int* offset_src_ptr,\n"
" #endif\n"
" __private const int x_size,\n"
" __private const int4 stride_src,\n"
" __private const int4 stride_dst,\n"
" __private const int2 steps,\n"
" __private const int2 iters,\n"
" __private const int inputSize) {\n"
" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n"
" \n"
" if (pos.x<global_dim0 && pos.y<global_dim1 && pos.z<global_dim2) {\n"
" \n"
" int x=pos.x % x_size;\n"
" int y=pos.x/x_size;\n"
" int2 index=(int2)(pos.z,pos.z);\n"
"#ifdef OFFSET_DST\n"
" index.x=offset_dst_ptr[pos.z];\n"
"#endif\n"
" \n"
"#ifdef OFFSET_SRC\n"
" index.y=offset_src_ptr[pos.z];\n"
"#endif\n"
" int2 offset=index*steps;\n"
" int src_offset=offset.y+stride_src.w+x*stride_src.x+y*stride_src.y+pos.y*stride_src.z;\n"
" int dst_offset=offset.x+stride_dst.w+x*stride_dst.x+y*stride_dst.y+pos.y*stride_dst.z;\n"
" if(offset.x >= 0){\n"
" if(offset.y >= 0 && offset.y<inputSize){\n"
" output[dst_offset]=(OUTPUT_TYPE)input[src_offset];\n"
" }else{\n"
" output[dst_offset]=(OUTPUT_TYPE)(0);\n"
" }\n"
" }\n"
" }\n"
"}\n"
;
#endif
}