mirror of https://github.com/alibaba/MNN.git
49 lines
1.3 KiB
Plaintext
49 lines
1.3 KiB
Plaintext
#version 440 core
|
|
|
|
// Destination storage buffer. main() writes exactly one mat4 per (x, z)
// invocation, at element index x + z * ow; the four columns hold the four
// accumulated gate vectors.
layout(set = 0, binding = 0) writeonly buffer destBuffer
{
    mat4 data[];
} uOutput;

// Input activations. main() reads them with texelFetch at integer
// coordinates (i, y, z), one vec4 texel per fetch.
layout(set = 0, binding = 1) uniform mediump sampler3D uInput;

// Flat array of scalar weights. main() addresses it as four consecutive
// blocks of ow * iw floats each, every block indexed by [x * iw + i].
layout(std430, binding = 2) readonly buffer weightBuffer
{
    float data[];
} uWeight;

// Dispatch parameters shared by every invocation.
layout(binding = 3) uniform constBuffer
{
    ivec4 shape; // ow, iw, channel, 0
} uLSTMGateParam;

// Work-group size: 8 invocations along x, 1 along y and z.
layout(local_size_x = 8, local_size_y = 1, local_size_z = 1) in;
|
|
|
|
// Computes, for one (x, z) invocation, four weighted sums over the input
// width and stores them together as the columns of a single mat4.
//
// Each sum multiplies the input texel at (i, y, z) by a scalar weight taken
// from one of four consecutive weight blocks and accumulates over i in
// [0, iw). The original code names the four blocks wi, wf, wo and wg, in
// that order -- presumably the LSTM input / forget / output / cell gates;
// confirm against the host-side weight packing.
void main()
{
    // Per the original author's note, the input is laid out as NC4HW4.
    ivec3 gid  = ivec3(gl_GlobalInvocationID);
    ivec3 dims = uLSTMGateParam.shape.xyz; // (ow, iw, channel)

    // Invocations outside the ow x channel range have nothing to do.
    // gid.y is not range-checked here and does not affect the output index.
    if (gid.x >= dims.x || gid.z >= dims.z)
    {
        return;
    }

    // Weight block g starts at g * (ow * iw); within a block, output column
    // gid.x owns the iw consecutive weights beginning at gid.x * iw.
    int   blockStride = dims.x * dims.y;
    ivec4 gateBase    = ivec4(gid.x * dims.y) + blockStride * ivec4(0, 1, 2, 3);

    // One accumulator column per weight block, all starting at zero.
    mat4 acc = mat4(vec4(0.0), vec4(0.0), vec4(0.0), vec4(0.0));

    for (int k = 0; k < dims.y; ++k)
    {
        vec4 texel = texelFetch(uInput, ivec3(k, gid.yz), 0);

        acc[0] += texel * vec4(uWeight.data[gateBase.x + k]);
        acc[1] += texel * vec4(uWeight.data[gateBase.y + k]);
        acc[2] += texel * vec4(uWeight.data[gateBase.z + k]);
        acc[3] += texel * vec4(uWeight.data[gateBase.w + k]);
    }

    // Results are stored row-major over (z, x): one mat4 per invocation.
    uOutput.data[gid.x + gid.z * dims.x] = acc;
}