// MNN/source/backend/opencl/execution/cl/layernorm_buf.cl

#ifdef MNN_SUPPORT_FP16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif

/*
 * layernorm_buf
 *
 * Normalizes each row of `input` and writes the result to `output`.
 * Row r (r = get_global_id(1)) covers the `inside` consecutive elements
 * starting at input[r * inside].  For every element x of the row:
 *
 *     out = (x - mean) / sqrt(mean((x - mean)^2) + epsilon)
 *     out = out * gamma[i] + beta[i]          (only when GAMMA_BETA is defined)
 *
 * All accumulation is done in float regardless of FLOAT.
 *
 * Build-time macros read by this kernel:
 *   FLOAT, CONVERT_FLOAT4 - storage type and float4 -> FLOAT4 conversion;
 *                           supplied by the build options, not defined here.
 *   LOCAL_SIZE            - when > 1, the work-items of a work-group cooperate
 *                           on one row through local memory.  The halving
 *                           reduction only covers every slot when LOCAL_SIZE
 *                           is a power of two.
 *                           NOTE(review): assumes the work-group size in
 *                           dimension 0 equals LOCAL_SIZE - confirm at the
 *                           host call site.
 *   RMSNORM               - when defined, the mean is fixed to 0 (no mean pass).
 *   GAMMA_BETA            - when defined, per-element scale/shift are applied
 *                           (and the gamma/beta arguments exist).
 *   PACK_LEAVE            - when defined, the last group of 4 elements of the
 *                           row is processed element by element
 *                           (inside - 4 * (ceil(inside / 4) - 1) elements)
 *                           instead of with a float4 load/store.
 *                           Presumably set when inside is not a multiple
 *                           of 4 - verify against the host code.
 *
 * Arguments:
 *   global_dim0, global_dim1 - valid extent of the global id in each dimension;
 *                              work-items outside it neither read nor write
 *                              global memory.
 *   input, output            - source / destination buffers.
 *   inside                   - number of elements normalized together (row length).
 *   gamma, beta              - per-element scale / shift, indexed 0..inside-1.
 *   epsilon                  - added to the variance before the square root.
 */
__kernel void layernorm_buf(__private int global_dim0, __private int global_dim1,
                        __global const FLOAT * input,
                        __global FLOAT * output,
                        __private const int inside,
#ifdef GAMMA_BETA
                        __global const FLOAT *gamma,
                        __global const FLOAT *beta,
#endif
                        __private float epsilon){
#if LOCAL_SIZE > 1
    __local float sum_mnn[LOCAL_SIZE];
#ifndef RMSNORM
    __local float sum_mean_mnn[LOCAL_SIZE];
#endif
#endif
    const int2 pos = (int2)(get_global_id(0), get_global_id(1));
    const int in_range = (pos.x < global_dim0 && pos.y < global_dim1);

#if LOCAL_SIZE > 1
    const int lid = get_local_id(0);
    const int offset = pos.y * inside;
    const int inside_v4 = (inside + 3) >> 2;
#ifdef PACK_LEAVE
    /* The last group of 4 is handled element-wise after the vector loop. */
    const int loop = inside_v4 - 1;
    const int inside_remain = inside - loop * 4;
#else
    const int loop = inside_v4;
#endif
    /*
     * Out-of-range work-items must still reach every barrier below, so they
     * are not branched around the reduction.  Instead they start past the end
     * of the row: every loop and tail test is false for them, they contribute
     * 0 to the partial sums and never touch global memory.
     */
    const int first = in_range ? lid : loop + 1;

#ifdef RMSNORM
    const float mean = 0.0f;
#else
    /* Pass 1: sum of the row -> mean. */
    {
        float4 acc4 = (float4)0;
        int index = first;
        for (; index < loop; index += LOCAL_SIZE) {
            acc4 += convert_float4(vload4(index, input + offset));
        }
        float partial = acc4.x + acc4.y + acc4.z + acc4.w;
#ifdef PACK_LEAVE
        /* Exactly one work-item ends the loop on the last group of 4. */
        if (index == loop) {
            for (int i = 0; i < inside_remain; ++i) {
                partial = partial + (float)input[offset + index * 4 + i];
            }
        }
#endif
        sum_mean_mnn[lid] = partial;
    }
    barrier(CLK_LOCAL_MEM_FENCE);
    for (int i = LOCAL_SIZE / 2; i > 0; i /= 2) {
        if (lid < i) {
            sum_mean_mnn[lid] = sum_mean_mnn[lid] + sum_mean_mnn[lid + i];
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    const float mean = sum_mean_mnn[0] / (float)inside;
#endif
    const float4 mean4 = (float4)(mean);

    /* Pass 2: sum of squared deviations -> variance. */
    {
        float4 acc4 = (float4)0;
        int index = first;
        for (; index < loop; index += LOCAL_SIZE) {
            float4 in = convert_float4(vload4(index, input + offset));
            acc4 += (in - mean4) * (in - mean4);
        }
        float partial = acc4.x + acc4.y + acc4.z + acc4.w;
#ifdef PACK_LEAVE
        if (index == loop) {
            for (int i = 0; i < inside_remain; ++i) {
                float in = (float)input[offset + index * 4 + i];
                partial = partial + (in - mean) * (in - mean);
            }
        }
#endif
        sum_mnn[lid] = partial;
    }
    barrier(CLK_LOCAL_MEM_FENCE);
    for (int i = LOCAL_SIZE / 2; i > 0; i /= 2) {
        if (lid < i) {
            sum_mnn[lid] = sum_mnn[lid] + sum_mnn[lid + i];
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    const float square_sum = sum_mnn[0] / (float)inside;
    const float value = 1.0f / sqrt(square_sum + epsilon);
    const float4 value4 = (float4)(value);

    /* Pass 3: normalize and store. */
    {
        int index = first;
        for (; index < loop; index += LOCAL_SIZE) {
            float4 in = convert_float4(vload4(index, input + offset));
#ifdef GAMMA_BETA
            float4 out = (in - mean4) * value4 * convert_float4(vload4(index, gamma)) + convert_float4(vload4(index, beta));
#else
            float4 out = (in - mean4) * value4;
#endif
            vstore4(CONVERT_FLOAT4(out), index, output + offset);
        }
#ifdef PACK_LEAVE
        if (index == loop) {
            for (int i = 0; i < inside_remain; ++i) {
                float in = (float)input[offset + index * 4 + i];
#ifdef GAMMA_BETA
                float out = (in - mean) * value * (float)gamma[index * 4 + i] + (float)beta[index * 4 + i];
#else
                float out = (in - mean) * value;
#endif
                output[offset + index * 4 + i] = out;
            }
        }
#endif
    }
#else
    /* Single work-item per row: plain sequential passes, no local memory. */
    if (in_range) {
        const int offset = pos.y * inside;
        float in_sum = 0;
#ifdef RMSNORM
        float mean = 0;
#else
        for (int index = 0; index < inside; index++) {
            in_sum += (float)input[offset + index];
        }
        float mean = in_sum / inside;
#endif
        in_sum = 0;
        for (int index = 0; index < inside; index++) {
            float in = (float)input[offset + index];
            in_sum += (in - mean) * (in - mean);
        }
        float square_sum = in_sum / inside;
        float value = 1.0f / sqrt(square_sum + epsilon);
        for (int i = 0; i < inside; ++i) {
            float in = input[offset + i];
#ifdef GAMMA_BETA
            float out = (in - mean) * value * (float)gamma[i] + (float)beta[i];
#else
            float out = (in - mean) * value;
#endif
            output[offset + i] = out;
        }
    }
#endif
}