MNN/source/backend/vulkan/execution/glsl/relu.comp

50 lines
1.3 KiB
Plaintext
Raw Normal View History

2019-04-17 10:49:11 +08:00
#version 450 core
// Default memory layout for every buffer (storage) block declared below.
layout(std430) buffer;
// Default memory layout for every uniform block declared below.
// NOTE(review): the GLSL spec ties std430 to shader storage blocks; confirm the
// shader compiler used by the build accepts it as a uniform-block default.
layout(std430) uniform;
// Input/output resources (descriptor set 0, bindings 0 and 1).
// The IMAGE macro selects, at compile time, between a 3D-image variant and a
// storage-buffer variant of this shader.
#ifdef IMAGE
// Image variant: results are stored into an rgba16f 3D image and the source is
// read through a 3D sampler.
layout(set=0, binding=0, rgba16f) writeonly restrict mediump uniform image3D uOutput;
layout(set=0, binding=1) uniform mediump sampler3D uInput;
#else
// Buffer variant: destination and source are unsized arrays of vec4.
layout(set = 0, binding = 0) writeonly buffer destBuffer{
    vec4 data[];
}uOutput;
layout(set = 0, binding = 1) readonly buffer sourceBuffer{
    vec4 data[];
}uInput;
#endif
// Kernel parameters (descriptor set 0, binding 2).
layout(set = 0, binding = 2) uniform reluBuffer{
    // Extent used for bounds checks: main() reads .xyz and compares x (and y in
    // the image variant) against the global invocation id.
    ivec4 imgSize;
    // Per-component factor multiplied into input components that are below zero.
    vec4 slope;
}uReluParam;
// Work-group size: 16x16x1 invocations for the image variant,
// 256x1x1 invocations for the buffer variant.
#ifdef IMAGE
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;
#else
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
#endif
// Entry point. Each invocation processes one texel (image variant) or one vec4
// element (buffer variant): components below zero are multiplied by
// uReluParam.slope, all other components are passed through unchanged.
void main()
{
    ivec3 pos = ivec3(gl_GlobalInvocationID);
    ivec3 imgSize = uReluParam.imgSize.xyz;
#ifdef IMAGE
    // Skip invocations whose x/y id lies outside the extent (z is not checked here).
    if(pos.x < imgSize.x && pos.y < imgSize.y)
    {
        vec4 dataIn = texelFetch(uInput, pos, 0);
        bvec4 lessZero = bvec4(lessThan(dataIn, vec4(0.0)));
        vec4 dataTemp = dataIn * uReluParam.slope;
        // mix() with a boolean selector picks dataTemp where lessZero is true
        // and dataIn where it is false, component by component.
        imageStore(uOutput, pos, mix(dataIn, dataTemp, lessZero));
    }
#else
    // Buffer variant indexes by x only; skip ids at or beyond imgSize.x.
    if(pos.x < imgSize.x)
    {
        vec4 dataIn = uInput.data[pos.x];
        bvec4 lessZero = bvec4(lessThan(dataIn, vec4(0.0)));
        vec4 dataTemp = dataIn * uReluParam.slope;
        // Component-wise select: scaled value where the input was negative.
        uOutput.data[pos.x] = mix(dataIn, dataTemp, lessZero);
    }
#endif
}