MNN/source/backend/vulkan/execution/glsl/lstmGate.comp

49 lines
1.3 KiB
Plaintext

#version 440 core
// Resource interface of the LSTM gate compute shader. main() reads texels
// from uInput and scalars from uWeight, and writes one mat4 per output
// element into uOutput.

// Output storage buffer (write-only). main() stores one mat4 per element,
// built from four vec4 accumulators, at index pos.x + pos.z * ow.
layout(set = 0, binding = 0) writeonly buffer destBuffer{
mat4 data[];
}uOutput;
// Input as a 3D texture; main() fetches it with integer coordinates
// (i, pos.y, pos.z) at mip level 0.
layout(set = 0, binding = 1) uniform mediump sampler3D uInput;
// Weights as a flat, read-only float array (std430). main() reads four
// sections, each ow * iw floats long, laid out back to back; inside a
// section the row for output column x starts at x * iw.
layout(binding = 2, std430) readonly buffer weightBuffer{
float data[];
}uWeight;
// Shape parameters supplied by the host.
layout(binding = 3) uniform constBuffer{
ivec4 shape; // ow, iw, channel, 0
}uLSTMGateParam;
// Workgroup size: 8 invocations along x, 1 along y and z.
layout(local_size_x = 8, local_size_y = 1, local_size_z = 1) in;
// For one (x, z) invocation inside the ow x channel extent, accumulates the
// products of every input texel along the input width with four weight rows
// and stores the four resulting vec4 sums as the columns of one mat4.
void main()
{
    // the layout of input is NC4HW4
    ivec3 gid = ivec3(gl_GlobalInvocationID);
    int outWidth = uLSTMGateParam.shape.x;
    int inWidth  = uLSTMGateParam.shape.y;
    int depth    = uLSTMGateParam.shape.z;

    // Only x and z are range-checked; y is passed straight to the fetch.
    bool inside = (gid.x < outWidth) && (gid.z < depth);
    if (!inside)
    {
        return;
    }

    // The weight buffer holds four sections of outWidth * inWidth floats each
    // (named i, f, o, g after the original variables; presumably the LSTM
    // input/forget/output/candidate gates, to be confirmed against the host).
    int sectionSize = outWidth * inWidth;
    int rowI = gid.x * inWidth;
    int rowF = rowI + sectionSize;
    int rowO = rowF + sectionSize;
    int rowG = rowO + sectionSize;

    // Column c of the accumulator collects the sum for section c.
    mat4 acc = mat4(0.0);

    int k = 0;
    while (k < inWidth)
    {
        vec4 texel = texelFetch(uInput, ivec3(k, gid.y, gid.z), 0);
        // Each scalar weight is broadcast across the four packed lanes.
        acc[0] += texel * vec4(uWeight.data[rowI + k]);
        acc[1] += texel * vec4(uWeight.data[rowF + k]);
        acc[2] += texel * vec4(uWeight.data[rowO + k]);
        acc[3] += texel * vec4(uWeight.data[rowG + k]);
        ++k;
    }

    // One mat4 per (x, z); elements are ordered x-fastest.
    uOutput.data[gid.z * outWidth + gid.x] = acc;
}