MNN/source/backend/vulkan/execution/glsl/reduce.comp

59 lines
1.3 KiB
Plaintext

#version 440 core

// Default memory layout for the storage buffer blocks declared below.
layout(std430) buffer;

// Binding 0: values to be reduced; this shader only reads from it.
layout(set = 0, binding = 0) readonly buffer sourceBuffer {
    float data[];
} uInput;

// Binding 1: reduction results; this shader only writes to it.
layout(set = 0, binding = 1) writeonly buffer destbuffer {
    float data[];
} uOutput;

// Binding 2: extents of the reduction and the scale used by the mean variant.
layout(set = 0, binding = 2) uniform constBuffer {
    int w;   // inside extent (innermost index range)
    int h;   // axis extent (the dimension being reduced)
    int c;   // outside extent (outermost index range)
    float k; // multiplier applied to the sum when MEAN is defined
} uConst;

// One-dimensional work group of 256 invocations along x.
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
// Reduces the axis dimension of uInput, addressed as [outside][axis][inside]
// with extents uConst.c, uConst.h and uConst.w, into uOutput addressed as
// [outside][inside].
//
// Each invocation handles one flat output index (gl_GlobalInvocationID.x).
// The reduction operator is chosen at compile time through the VMAX, VMIN,
// SUM, PROD and MEAN macros (presumably exactly one is defined per variant;
// if none is defined the first element along the axis is copied through).
void main()
{
    int inside  = uConst.w;
    int axisLen = uConst.h;

    // Integer division and modulo by zero yield undefined values in GLSL,
    // which could turn into arbitrary buffer indices below. Nothing can be
    // reduced when the inside extent is not positive, so stop early.
    if (inside <= 0)
    {
        return;
    }

    int flatIndex = int(gl_GlobalInvocationID.x);
    int outerIdx  = flatIndex / inside; // index in outside
    int innerIdx  = flatIndex % inside; // index in inside, always < inside

    // Invocations past the last output element (dispatch padding) do nothing.
    if (outerIdx >= uConst.c)
    {
        return;
    }

    // First element along the axis for this (outside, inside) pair; successive
    // axis elements are `inside` floats apart.
    int startIndex = outerIdx * axisLen * inside + innerIdx;
    float res = uInput.data[startIndex];
    for (int i = 1; i < axisLen; ++i)
    {
        float next = uInput.data[startIndex + i * inside];
#ifdef VMAX
        res = max(res, next);
#endif
#ifdef VMIN
        res = min(res, next);
#endif
#ifdef SUM
        res = res + next;
#endif
#ifdef PROD
        res = res * next;
#endif
#ifdef MEAN
        res = res + next;
#endif
    }
#ifdef MEAN
    // The mean is the accumulated sum scaled by the host-provided factor.
    res = res * uConst.k;
#endif
    uOutput.data[flatIndex] = res;
}