mirror of https://github.com/alibaba/MNN.git
99 lines
2.4 KiB
Plaintext
99 lines
2.4 KiB
Plaintext
#version 310 es

// Default float precision for the shader, selected through a macro.
#define PRECISION highp
precision PRECISION float;

// Default memory layout for the buffer blocks declared below.
layout(std140) buffer;

// Destination storage buffer: main() stores its results here and also reads
// back values it has written.
layout(set = 0, binding = 0) buffer destbuffer {
    vec4 data[];
} uOutput;

// Source storage buffer; read-only.
layout(set = 0, binding = 1) readonly buffer sourceBuffer {
    vec4 data[];
} uInput;

// Dispatch parameters.
layout(set = 0, binding = 2) uniform constBuffer {
    int outputSize;        // exclusive upper bound on the vec4 index main() touches
    int oneInvocationLen;  // consecutive vec4 elements handled by one invocation
} uConst;

// 64 invocations per work group along x; y and z keep their default size of 1.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

// One slot per local invocation (length equals local_size_x). main() uses it
// for the per-invocation partial maximum and later for the partial sum.
shared vec4 sharedMaxValue[64];

// Work-group-wide values that main() broadcasts to every invocation.
shared vec4 sumValue;
shared vec4 maxValue;

// Scalar scratch values for the work-group maximum and the work-group sum.
shared float trueMaxValue;
shared float sum;
|
|
|
|
// Numerically stable softmax over the vec4 elements of uInput, written to
// uOutput:  out = exp(in - max) / sum(exp(in - max)).
//
// Each invocation owns the half-open element range [dataIndex, dataEnd), which
// is uConst.oneInvocationLen elements clamped to uConst.outputSize. The maximum
// and the sum are reduced across all four vec4 lanes and across the 64
// invocations of ONE work group through shared memory.
//
// NOTE(review): the reduction never crosses work groups, so if more than one
// group is dispatched each group normalises only its own slice. Presumably the
// host dispatches a single group per softmax; confirm against the caller.
// NOTE(review): all four lanes of every vec4 take part in the max and the sum,
// so any padding lanes in the last element are assumed to hold valid data.
void main()
{
    // Must match local_size_x and the length of sharedMaxValue.
    const int kGroupSize = 64;
    // Lowest finite highp float: an identity for max(), so invocations with an
    // empty range do not disturb the reduction.
    const float kLowest = -3.402823466e+38;

    int localIndex = int(gl_LocalInvocationID.x);
    int dataIndex  = int(gl_GlobalInvocationID.x) * uConst.oneInvocationLen;
    int dataEnd    = min(dataIndex + uConst.oneInvocationLen, uConst.outputSize);

    // 1) Maximum over this invocation's range. Starting from kLowest instead of
    //    uInput.data[dataIndex] avoids an out-of-bounds read when the range is
    //    empty (dataIndex >= outputSize).
    vec4 localMax = vec4(kLowest);
    for (int i = dataIndex; i < dataEnd; ++i)
    {
        localMax = max(localMax, uInput.data[i]);
    }
    sharedMaxValue[localIndex] = localMax;

    // A memory barrier alone does not make other invocations wait; barrier()
    // is required so that every slot is written before invocation 0 reads it.
    memoryBarrierShared();
    barrier();

    // 2) Maximum within the work group, collapsed to a single scalar.
    if (localIndex == 0)
    {
        vec4 groupMax = sharedMaxValue[0];
        for (int i = 1; i < kGroupSize; ++i)
        {
            groupMax = max(groupMax, sharedMaxValue[i]);
        }
        trueMaxValue = max(max(groupMax.x, groupMax.y), max(groupMax.z, groupMax.w));
        maxValue = vec4(trueMaxValue);
    }

    memoryBarrierShared();
    barrier();

    // 3) exp(x - max), stored to the output, plus this invocation's partial
    //    sum. Summing while storing avoids reading the output buffer back.
    vec4 groupMaxValue = maxValue;
    vec4 localSum = vec4(0.0);
    for (int i = dataIndex; i < dataEnd; ++i)
    {
        vec4 e = exp(uInput.data[i] - groupMaxValue);
        uOutput.data[i] = e;
        localSum += e;
    }
    // Reusing sharedMaxValue is safe here: invocation 0 finished reading the
    // maxima before the barrier above.
    sharedMaxValue[localIndex] = localSum;

    memoryBarrierShared();
    barrier();

    // 4) Sum within the work group, collapsed across lanes to a single scalar.
    if (localIndex == 0)
    {
        vec4 groupSum = vec4(0.0);
        for (int i = 0; i < kGroupSize; ++i)
        {
            groupSum += sharedMaxValue[i];
        }
        sum = groupSum.x + groupSum.y + groupSum.z + groupSum.w;
        sumValue = vec4(sum);
    }

    memoryBarrierShared();
    barrier();

    // 5) Normalise this invocation's range.
    vec4 denom = sumValue;
    for (int i = dataIndex; i < dataEnd; ++i)
    {
        uOutput.data[i] /= denom;
    }
}
|