// MNN/source/backend/vulkan/execution/glsl/winogradTransformSource2_3_...
// (86 lines, 3.5 KiB, plaintext)

#version 450 core
// Resource interface of this compute shader. All resources live in descriptor set 0.

// Default block layout qualifiers for buffer and uniform blocks.
layout(std430) buffer;
layout(std430) uniform;
// binding 0: write-only rgba16f 2D image that main() fills with imageStore().
layout(set=0, binding=0, rgba16f) writeonly restrict uniform image2D uOutput;
// binding 1: 3D texture read in main() with texelFetch() at LOD 0; the z coordinate
// used there is the z component of the global invocation ID.
layout(set=0, binding=1) uniform sampler3D uInput;
// binding 2: read-only parameters.
//   inputSize, outputSize : declared here but not read by main() in this file.
//   padX, padY            : subtracted from the source start coordinate of a tile.
//   unitWidth, unitHeight : upper bounds for the x/y invocation IDs that do work;
//                           also used to compute the destination position.
//   unit                  : multiplier applied to the (offset) tile position to get
//                           the source start coordinate.
layout(set=0, binding=2) readonly restrict uniform constBuffer {
ivec4 inputSize;
ivec4 outputSize;
int padX;
int padY;
int unitWidth;
int unitHeight;
int unit;
} uConst;
// binding 3: read-only 2D offset added to the x/y invocation ID before the source
// coordinate is computed. The destination position does not use it.
layout(set=0, binding=3) readonly restrict uniform offsetBuffer {
ivec2 offset;
} uOffset;
// Work-group size: 8 x 8 x 1 invocations.
layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;
// Entry point: each invocation with x < unitWidth and y < unitHeight reads a 4x4
// patch of uInput (slice gl_GlobalInvocationID.z), applies the same 4-tap
// combination first along y and then along x, and writes the 16 results to
// uOutput. (The file name suggests this is the Winograd F(2,3) source transform.)
//
// The 4-tap combination of (a, b, c, d) is:
//   a - c,   0.5*b + 0.5*c,   -0.5*b + 0.5*c,   -b + d
void main()
{
    ivec3 gid = ivec3(gl_GlobalInvocationID);

    // Invocations outside the unitWidth x unitHeight grid do nothing.
    if (gid.x >= uConst.unitWidth || gid.y >= uConst.unitHeight)
    {
        return;
    }

    // Only the source position is shifted by uOffset.offset; the destination
    // position below is derived from the raw invocation ID.
    ivec2 tile = gid.xy + uOffset.offset;

    // Linear index of this invocation inside the grid and the row distance
    // between two consecutive result planes in uOutput.
    int tileIndex   = uConst.unitWidth * gid.y + gid.x;
    int planeStride = (uConst.unitWidth * uConst.unitHeight + 3) / 4;

    // Destination texel inside plane 0: four grid entries share one row, and
    // each z slice owns four consecutive columns.
    ivec2 dst = ivec2(tileIndex % 4 + 4 * gid.z, tileIndex / 4);

    // Top-left corner of the 4x4 source patch.
    ivec2 src = tile * uConst.unit - ivec2(uConst.padX, uConst.padY);

    // First pass: for every patch column x, combine its four rows.
    // partial[x][y] holds result y of column x.
    vec4 partial[4][4];
    for (int x = 0; x < 4; ++x)
    {
        vec4 r0 = texelFetch(uInput, ivec3(src.x + x, src.y + 0, gid.z), 0);
        vec4 r1 = texelFetch(uInput, ivec3(src.x + x, src.y + 1, gid.z), 0);
        vec4 r2 = texelFetch(uInput, ivec3(src.x + x, src.y + 2, gid.z), 0);
        vec4 r3 = texelFetch(uInput, ivec3(src.x + x, src.y + 3, gid.z), 0);

        partial[x][0] = +r0 - r2;
        partial[x][1] = +0.5 * r1 + 0.5 * r2;
        partial[x][2] = -0.5 * r1 + 0.5 * r2;
        partial[x][3] = -r1 + r3;
    }

    // Second pass: for every intermediate row y, combine its four columns and
    // store the results in planes 4*y .. 4*y+3.
    for (int y = 0; y < 4; ++y)
    {
        int plane = 4 * y;

        imageStore(uOutput, ivec2(dst.x, dst.y + planeStride * (plane + 0)),
                   +partial[0][y] - partial[2][y]);
        imageStore(uOutput, ivec2(dst.x, dst.y + planeStride * (plane + 1)),
                   +0.5 * partial[1][y] + 0.5 * partial[2][y]);
        imageStore(uOutput, ivec2(dst.x, dst.y + planeStride * (plane + 2)),
                   -0.5 * partial[1][y] + 0.5 * partial[2][y]);
        imageStore(uOutput, ivec2(dst.x, dst.y + planeStride * (plane + 3)),
                   -partial[1][y] + partial[3][y]);
    }
}