MNN/source/backend/vulkan/execution/glsl/softmax.comp

99 lines
2.4 KiB
Plaintext

#version 310 es
// Compute shader that applies a max-shifted softmax (exp(x - max) / sum) to a
// buffer of vec4 values. The declarations below define the storage buffers, the
// uniform parameters, the work-group size and the shared scratch used by main().
#define PRECISION highp
precision PRECISION float;
// Default memory layout for all buffer blocks declared below.
layout(std140) buffer;
// Destination buffer: main() first stores exp(x - max) here, then divides each
// element in place by the work-group sum.
layout(set=0, binding=0) buffer destbuffer{
vec4 data[];
}uOutput;
// Source buffer: read-only input values, indexed identically to uOutput.data.
layout(set=0, binding=1) readonly buffer sourceBuffer{
vec4 data[];
}uInput;
// Dispatch parameters supplied by the host.
layout(set=0, binding=2) uniform constBuffer {
// Exclusive upper bound on the data[] index processed by main(). It is compared
// against vec4 indices, so it appears to be counted in vec4 elements — TODO
// confirm against the host code.
int outputSize;
// Number of consecutive data[] entries handled by a single invocation.
int oneInvocationLen;
}uConst;
// 64 invocations per work-group; must match the size of sharedMaxValue and the
// hard-coded 64 in main()'s reduction loops.
layout(local_size_x = 64) in;
// One slot per invocation. Holds the per-invocation partial max first and is
// later reused for the per-invocation partial sum.
shared vec4 sharedMaxValue[64];
// Work-group sum of exponentials, broadcast to all four lanes by invocation 0.
shared vec4 sumValue;
// Work-group maximum, broadcast to all four lanes by invocation 0.
shared vec4 maxValue;
// Scalar scratch used by invocation 0 while collapsing maxValue across lanes.
shared float trueMaxValue;
// Scalar scratch used by invocation 0 while collapsing sumValue across lanes.
shared float sum;
// Max-shifted softmax over the slice of uInput covered by this work-group.
//
// Each invocation owns the contiguous index range
//   [gl_GlobalInvocationID.x * oneInvocationLen,
//    min(that + oneInvocationLen, outputSize))
// of the vec4 arrays. The max and the sum are reduced across all four lanes of
// every vec4 and across the 64 invocations of the work-group, so every scalar
// in the work-group's slice is normalized by the same maximum and the same sum.
// NOTE(review): the reduction is per work-group only; a dispatch with more than
// one work-group normalizes each group's slice independently — confirm the host
// dispatches accordingly.
//
// Synchronization: every shared-memory hand-off between invocations is guarded
// by memoryBarrierShared() followed by barrier(), so the reader cannot run ahead
// of the writers. All barrier() calls sit in uniform control flow.
void main()
{
    // Seed for the max reductions: lower than any finite value the input is
    // expected to hold, so an invocation with no data cannot influence the max.
    const float kLowest = -3.0e38;

    int localIndex = int(gl_LocalInvocationID.x);
    int dataIndex = int(gl_GlobalInvocationID.x) * uConst.oneInvocationLen;
    int invocationDataLen = min(dataIndex + uConst.oneInvocationLen, uConst.outputSize);

    // 1) Per-invocation partial max. It is accumulated in a private variable and
    //    never touches uInput outside [dataIndex, invocationDataLen).
    vec4 localMax = vec4(kLowest);
    for (int i = dataIndex; i < invocationDataLen; ++i)
    {
        localMax = max(localMax, uInput.data[i]);
    }
    sharedMaxValue[localIndex] = localMax;
    // Make the partial results visible AND wait for every invocation to have
    // written its slot before invocation 0 reads them.
    memoryBarrierShared();
    barrier();

    // 2) Invocation 0 reduces the partial maxima across the work-group and then
    //    across the four lanes, and broadcasts the scalar result.
    if (localIndex == 0)
    {
        sumValue = vec4(0.0);
        vec4 groupMax = vec4(kLowest);
        for (int i = 0; i < 64; ++i) // 64 == local_size_x == sharedMaxValue length
        {
            groupMax = max(groupMax, sharedMaxValue[i]);
        }
        trueMaxValue = max(max(groupMax.x, groupMax.y), max(groupMax.z, groupMax.w));
        maxValue = vec4(trueMaxValue);
    }
    memoryBarrierShared();
    barrier();

    // 3) exp(x - max) for this invocation's range, accumulating the partial sum
    //    in the same pass (same summation order as a separate second loop).
    vec4 shift = maxValue;
    vec4 localSum = vec4(0.0);
    for (int i = dataIndex; i < invocationDataLen; ++i)
    {
        vec4 e = exp(uInput.data[i] - shift);
        uOutput.data[i] = e;
        localSum += e;
    }
    // Safe to overwrite the slot: invocation 0 finished reading the maxima
    // before the barrier above.
    sharedMaxValue[localIndex] = localSum;
    memoryBarrierShared();
    barrier();

    // 4) Invocation 0 reduces the partial sums across the work-group and the
    //    four lanes, and broadcasts the scalar denominator.
    if (localIndex == 0)
    {
        vec4 total = vec4(0.0);
        for (int i = 0; i < 64; ++i)
        {
            total += sharedMaxValue[i];
        }
        sum = total.x + total.y + total.z + total.w;
        sumValue = vec4(sum);
    }
    memoryBarrierShared();
    barrier();

    // 5) Normalize. Whenever the work-group holds any data, the maximum element
    //    contributes exp(0) = 1, so the denominator is at least 1.
    vec4 denom = sumValue;
    for (int i = dataIndex; i < invocationDataLen; ++i)
    {
        uOutput.data[i] /= denom;
    }
}