MNN/source/backend/vulkan/buffer/execution/glsl/softmaxHeight_NHWC.comp

48 lines
1.2 KiB
Plaintext

#version 440 core
// Resource interface for the softmax compute shader below.
// All three resources live in descriptor set 0.

// Make std430 the default memory layout for the buffer blocks declared in this shader.
layout(std430) buffer;
// Destination buffer (binding 0): main() writes the normalized results here.
layout(set=0, binding=0) buffer destbuffer{
float data[];
}uOutput;
// Source buffer (binding 1): declared readonly, main() only loads from it.
layout(set=0, binding=1) readonly buffer sourceBuffer{
float data[];
}uInput;
// Shape parameters (binding 2). main() addresses both buffers as a flat
// [outside][axis][inside] array, i.e. element index = (outside * h + axis) * w + inside.
layout(set = 0, binding = 2) uniform constBuffer {
int w;// inside extent; also the stride between consecutive elements along the axis
int h;// axis extent: number of elements each softmax is computed over
int c;// outside extent; upper bound for gl_GlobalInvocationID.x
}uConst;
// Work-group size: 8 x 8 x 1 invocations (x indexes outside, y indexes inside in main()).
layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;
// Computes softmax along the axis dimension for one (outside, inside) pair.
//
// Both buffers are addressed as a flat [outside][axis][inside] array, so the
// elements this invocation handles start at `outside * h * w + inside` and are
// `w` floats apart. The work is done in three passes over those `h` elements:
//   1. find the maximum,
//   2. accumulate sum(exp(x - max)),
//   3. write exp(x - max) / sum to the output.
// Subtracting the maximum keeps every exponent argument <= 0, so exp() cannot
// overflow.
void main()
{
    // input tensor's layout is NHWC
    ivec3 pos = ivec3(gl_GlobalInvocationID);
    int W = uConst.w; // inside extent == stride along the axis
    int H = uConst.h; // axis extent

    // x: index in outside, y: index in inside.
    // Invocations that fall outside the tensor have nothing to do. An empty
    // axis is rejected here as well so that the initial load below never
    // touches an element that does not exist.
    if (pos.x >= uConst.c || pos.y >= W || H <= 0)
    {
        return;
    }

    // Offset of the first axis element for this (outside, inside) pair;
    // loop-invariant, so computed once instead of in every iteration.
    int base = pos.x * H * W + pos.y;

    // Pass 1: maximum along the axis.
    float maxValue = uInput.data[base];
    for (int i = 1; i < H; ++i)
    {
        maxValue = max(maxValue, uInput.data[base + i * W]);
    }

    // Pass 2: sum of the shifted exponentials.
    float sum = 0.0;
    for (int i = 0; i < H; ++i)
    {
        sum += exp(uInput.data[base + i * W] - maxValue);
    }

    // Pass 3: normalize and store.
    for (int i = 0; i < H; ++i)
    {
        int index = base + i * W;
        uOutput.data[index] = exp(uInput.data[index] - maxValue) / sum;
    }
}