MNN/source/backend/vulkan/execution/glsl/lstmSave.comp

47 lines
1.1 KiB
Plaintext

#version 440 core
// Source: read-only storage buffer of floats, indexed linearly by main().
layout(set = 0, binding = 0, std430) readonly buffer srcBuffer{
float data[];
}uInput;
// Destination: write-only 3D image with rgba16f texels; main() stores one vec4 per invocation.
layout(set = 0, binding = 1, rgba16f) writeonly restrict uniform image3D uOutput;
// Shader parameters supplied through a uniform block.
layout(set = 0, binding = 2) uniform constBuffer{
// info.x = ow: x extent of the dispatch and the stride between the values packed into one texel.
// info.y = oz: main() splits it into groups of four (oz / 4 and oz % 4).
// info.z, info.w: not read by main().
ivec4 info; // ow, oz, 0, 0
}uLSTM;
// Work-group size is 16 along x; y and z are left at their default of 1.
layout(local_size_x = 16) in;
// Packs values from the float buffer uInput into the RGBA texels of uOutput.
//
// Each invocation writes the texel at gl_GlobalInvocationID. For texel
// (x, y, z) it reads data[x + (4*z + k) * ow] for k = 0..3 and stores them in
// components k of the texel; entries whose index 4*z + k is >= oz are written
// as 0 instead of being read. pos.y only selects the destination texel and
// does not affect which source values are read.
// NOTE(review): this matches a planar [oz][ow] source layout -- confirm
// against the host code that fills uInput.
void main()
{
    ivec3 pos = ivec3(gl_GlobalInvocationID);

    int ow = uLSTM.info.x;
    int oz = uLSTM.info.y;

    // pos.z counts groups of four channels, so it must be bounded by the
    // rounded-up group count rather than by oz itself; comparing against oz
    // would let invocations with z >= sliceCount read past the last channel.
    int sliceCount = (oz + 3) / 4;
    if (pos.x >= ow || pos.z >= sliceCount)
    {
        return;
    }

    // Index of the first value for this texel; the following ones are ow apart.
    int base = pos.x + pos.z * 4 * ow;

    // Number of channels still available starting at this slice. The guard
    // above guarantees remaining >= 1; it is 1..3 only for a partial last slice.
    int remaining = oz - pos.z * 4;

    vec4 texel = vec4(0.0);
    texel.x = uInput.data[base];
    if (remaining > 1)
    {
        texel.y = uInput.data[base + ow];
    }
    if (remaining > 2)
    {
        texel.z = uInput.data[base + 2 * ow];
    }
    if (remaining > 3)
    {
        texel.w = uInput.data[base + 3 * ow];
    }

    imageStore(uOutput, pos, texel);
}