mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
	
	
		
			252 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			C++
		
	
	
	
		
		
			
		
	
	
			252 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			C++
		
	
	
	
|  | //
 | ||
|  | //  winogradGenerateGLSL.cpp
 | ||
|  | //  MNN
 | ||
|  | //
 | ||
|  | //  Created by MNN on 2019/01/22.
 | ||
|  | //  Copyright © 2018, Alibaba Group Holding Limited
 | ||
|  | //
 | ||
|  | 
 | ||
|  | #include <string.h>
 | ||
|  | #include <fstream>
 | ||
|  | #include <sstream>
 | ||
|  | #include <MNN/MNNDefine.h>
 | ||
|  | #include "math/Matrix.hpp"
 | ||
|  | #include "math/WingoradGenerater.hpp"
 | ||
|  | 
 | ||
|  | using namespace std; | ||
|  | 
 | ||
// Fixed leading text of the generated "winogradTransformSource*.comp" compute shader.
// It declares the output image, the input sampler and the constBuffer uniform block,
// opens main() and the `pos.x/pos.y` bounds-check block, and computes the destination
// and source start indices. main() streams this text first and appends the generated
// load/transform statements after it; the two scopes left open here are closed by
// gWinogradSourceTail.
// Written as a raw string literal: every character between the delimiters, including
// the line breaks and the uneven GLSL indentation, is part of the emitted text.
const char* gWinogradSourceHead = R"glsl(#version 450 core
layout(std430) buffer;
layout(std430) uniform;
layout(set=0, binding=0, rgba16f) writeonly restrict uniform image2D uOutput;
layout(set=0, binding=1) uniform sampler3D uInput;
layout(set=0, binding=2) readonly restrict uniform constBuffer {
    ivec4 inputSize;
    ivec4 outputSize;
    int padX;
    int padY;
    int unitWidth;
    int unitHeight;
    int unit;
} uConst;
layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;
void main()
{
    ivec3 pos = ivec3(gl_GlobalInvocationID);
    if (pos.x < uConst.unitWidth && pos.y < uConst.unitHeight)
    {
int dstXOrigin = pos.z;
int dstYOrigin = uConst.unitWidth * pos.y + pos.x;
int srcHeight = (uConst.unitWidth*uConst.unitHeight+3)/4;
int dstY = dstYOrigin / 4;
int dstX = dstYOrigin % 4 + 4*dstXOrigin;
        int sxStart = pos.x*uConst.unit - uConst.padX;
        int syStart = pos.y*uConst.unit - uConst.padY;
)glsl";
|  | 
 | ||
// Closing text of the generated source-transform shader: first the brace of the
// bounds-check block, then the brace of main(), both opened in gWinogradSourceHead.
const char* gWinogradSourceTail = "    }\n}\n";
|  | 
 | ||
// Fixed leading text of the generated "winogradTransformDest*.comp" compute shader.
// It declares the 3D output image, the 2D input and bias samplers and the constBuffer
// uniform block, opens main() and the `pos.x/pos.y` bounds-check block, fetches the bias
// texel and computes the read/write start indices. main() streams this text first and
// appends the generated load/transform/store statements after it; the two scopes left
// open here are closed by gWinogradDestTail.
// Written as a raw string literal: every character between the delimiters, including
// the line breaks and the uneven GLSL indentation, is part of the emitted text.
const char* gWinogradDestHead = R"glsl(#version 450 core
layout(std430) buffer;
layout(std430) uniform;
layout(set=0, binding=0, rgba16f) writeonly restrict uniform image3D uOutput;
layout(set=0, binding=1) uniform sampler2D uInput;
layout(set=0, binding=2) uniform sampler2D uBias;
layout(set=0, binding=3) readonly restrict uniform constBuffer {
    ivec4 inputSize;
    ivec4 outputSize;
    int padX;
    int padY;
    int unitWidth;
    int unitHeight;
    int unit;
} uConst;
layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;
void main()
{
    ivec3 pos = ivec3(gl_GlobalInvocationID);
    if (pos.x < uConst.unitWidth && pos.y < uConst.unitHeight)
    {
int dstWidth = (uConst.unitWidth*uConst.unitHeight+3)/4;
int dstXOrigin = uConst.unitWidth * pos.y + pos.x;
int dstX = dstXOrigin / 4;
int dstY = 4*pos.z + dstXOrigin % 4;
        vec4 bias = texelFetch(uBias, ivec2(pos.z, 0), 0);
int oyStart = pos.y * uConst.unit;
int oxStart = pos.x * uConst.unit;
int oz = pos.z;
)glsl";
|  | 
 | ||
// Closing text of the generated dest-transform shader: first the brace of the
// bounds-check block, then the brace of main(), both opened in gWinogradDestHead.
const char* gWinogradDestTail = "    }\n}\n";
|  | 
 | ||
|  | int main(int argc, const char* argv[]) { | ||
|  |     int unit       = atoi(argv[1]); | ||
|  |     int kernelSize = atoi(argv[2]); | ||
|  |     auto alpha     = unit + kernelSize - 1; | ||
|  |     float interp   = 0.5f; | ||
|  |     if (argc > 3) { | ||
|  |         interp = atof(argv[3]); | ||
|  |     } | ||
|  |     MNN::Math::WinogradGenerater generater(unit, kernelSize, interp); | ||
|  |     auto a = generater.A(); | ||
|  |     auto b = generater.B(); | ||
|  |     auto g = generater.G(); | ||
|  | 
 | ||
|  |     MNN::Math::Matrix::print(a.get(), "A"); | ||
|  |     MNN::Math::Matrix::print(b.get(), "B"); | ||
|  |     MNN::Math::Matrix::print(g.get(), "G"); | ||
|  |     std::ostringstream sourceFileOstream; | ||
|  |     { sourceFileOstream << "winogradTransformSource" << unit << "_" << kernelSize << "_" << interp << ".comp"; } | ||
|  |     auto sourceFile = sourceFileOstream.str(); | ||
|  |     MNN_PRINT("Generate %s\n", sourceFile.c_str()); | ||
|  |     { | ||
|  |         std::ofstream sourceOutput(sourceFile.c_str()); | ||
|  |         sourceOutput << gWinogradSourceHead << "{\n"; | ||
|  | 
 | ||
|  |         // Load
 | ||
|  |         for (int y = 0; y < alpha; ++y) { | ||
|  |             for (int x = 0; x < alpha; ++x) { | ||
|  |                 sourceOutput << "vec4 S" << x << y << "= texelFetch(uInput, ivec3(sxStart+" << x << ", syStart+ " << y | ||
|  |                              << ", pos.z), 0);\n"; | ||
|  |             } | ||
|  |         } | ||
|  | 
 | ||
|  |         // M = BT*S
 | ||
|  |         auto bFloat = b->host<float>(); | ||
|  |         for (int y = 0; y < alpha; ++y) { | ||
|  |             for (int x = 0; x < alpha; ++x) { | ||
|  |                 sourceOutput << "vec4 m" << x << y << "= "; | ||
|  | 
 | ||
|  |                 for (int k = 0; k < alpha; ++k) { | ||
|  |                     auto value = bFloat[alpha * k + y]; | ||
|  |                     if (0.0f == value) { | ||
|  |                         continue; | ||
|  |                     } else if (1.0f == value) { | ||
|  |                         sourceOutput << "+S" << x << k; | ||
|  |                     } else if (-1.0f == value) { | ||
|  |                         sourceOutput << "-S" << x << k; | ||
|  |                     } else { | ||
|  |                         if (value > 0) { | ||
|  |                             sourceOutput << "+"; | ||
|  |                         } | ||
|  |                         sourceOutput << value << "*S" << x << k; | ||
|  |                     } | ||
|  |                 } | ||
|  |                 sourceOutput << ";\n"; | ||
|  |             } | ||
|  |         } | ||
|  | 
 | ||
|  |         // S = M*B
 | ||
|  |         for (int y = 0; y < alpha; ++y) { | ||
|  |             for (int x = 0; x < alpha; ++x) { | ||
|  |                 sourceOutput << "imageStore(uOutput, ivec2(dstX, dstY+srcHeight*" << (y * alpha + x) << "), "; | ||
|  |                 for (int k = 0; k < alpha; ++k) { | ||
|  |                     auto value = bFloat[alpha * k + x]; | ||
|  |                     if (0.0f == value) { | ||
|  |                         continue; | ||
|  |                     } else if (1.0f == value) { | ||
|  |                         sourceOutput << "+m" << k << y; | ||
|  |                     } else if (-1.0f == value) { | ||
|  |                         sourceOutput << "-m" << k << y; | ||
|  |                     } else { | ||
|  |                         if (value > 0) { | ||
|  |                             sourceOutput << "+"; | ||
|  |                         } | ||
|  |                         sourceOutput << value << "*m" << k << y; | ||
|  |                     } | ||
|  |                 } | ||
|  |                 sourceOutput << ");\n"; | ||
|  |             } | ||
|  |         } | ||
|  | 
 | ||
|  |         sourceOutput << "}\n"; | ||
|  |         sourceOutput << gWinogradSourceTail; | ||
|  |     } | ||
|  | 
 | ||
|  |     std::ostringstream destFileOstream; | ||
|  |     { destFileOstream << "winogradTransformDest" << unit << "_" << kernelSize << "_" << interp << ".comp"; } | ||
|  |     auto destFile = destFileOstream.str(); | ||
|  |     MNN_PRINT("Generate %s\n", destFile.c_str()); | ||
|  |     { | ||
|  |         std::ofstream destFileOs(destFile.c_str()); | ||
|  |         destFileOs << gWinogradDestHead; | ||
|  |         destFileOs << "{\n"; | ||
|  |         auto aFloat = a->host<float>(); | ||
|  | 
 | ||
|  |         // Load
 | ||
|  |         for (int y = 0; y < alpha; ++y) { | ||
|  |             for (int x = 0; x < alpha; ++x) { | ||
|  |                 destFileOs << "vec4 S" << x << y << "= texelFetch(uInput, ivec2(dstX+dstWidth*" << (x + y * alpha) | ||
|  |                            << ", dstY), 0);\n"; | ||
|  |             } | ||
|  |         } | ||
|  | 
 | ||
|  |         // M = AT* S
 | ||
|  |         for (int y = 0; y < unit; ++y) { | ||
|  |             for (int x = 0; x < alpha; ++x) { | ||
|  |                 destFileOs << "vec4 m" << x << y << "= "; | ||
|  |                 for (int k = 0; k < alpha; ++k) { | ||
|  |                     auto value = aFloat[k * unit + y]; | ||
|  |                     if (0.0f == value) { | ||
|  |                         continue; | ||
|  |                     } else if (1.0f == value) { | ||
|  |                         destFileOs << "+S" << x << k; | ||
|  |                     } else if (-1.0f == value) { | ||
|  |                         destFileOs << "-S" << x << k; | ||
|  |                     } else { | ||
|  |                         if (value > 0) { | ||
|  |                             destFileOs << "+"; | ||
|  |                         } | ||
|  |                         destFileOs << value << "*S" << x << k; | ||
|  |                     } | ||
|  |                 } | ||
|  |                 destFileOs << ";\n"; | ||
|  |             } | ||
|  |         } | ||
|  | 
 | ||
|  |         // S = M * A
 | ||
|  |         for (int y = 0; y < unit; ++y) { | ||
|  |             for (int x = 0; x < unit; ++x) { | ||
|  |                 destFileOs << "{\n"; | ||
|  |                 destFileOs << "vec4 res = bias"; | ||
|  |                 for (int k = 0; k < alpha; ++k) { | ||
|  |                     auto value = aFloat[k * unit + x]; | ||
|  |                     if (0.0f == value) { | ||
|  |                         continue; | ||
|  |                     } else if (1.0f == value) { | ||
|  |                         destFileOs << "+m" << k << y; | ||
|  |                     } else if (-1.0f == value) { | ||
|  |                         destFileOs << "-m" << k << y; | ||
|  |                     } else { | ||
|  |                         if (value > 0) { | ||
|  |                             destFileOs << "+"; | ||
|  |                         } | ||
|  |                         destFileOs << value << "*m" << k << y; | ||
|  |                     } | ||
|  |                 } | ||
|  |                 destFileOs << ";\n"; | ||
|  | 
 | ||
|  |                 destFileOs << "#ifdef RELU\n"; | ||
|  |                 destFileOs << "res = max(res, vec4(0));\n"; | ||
|  |                 destFileOs << "#endif\n"; | ||
|  |                 destFileOs << "#ifdef RELU6\n"; | ||
|  |                 destFileOs << "res = clamp(res, vec4(0), vec4(6));\n"; | ||
|  |                 destFileOs << "#endif\n"; | ||
|  |                 destFileOs << "imageStore(uOutput, ivec3(oxStart+" << x << ", oyStart+" << y << ", pos.z), res);\n"; | ||
|  | 
 | ||
|  |                 destFileOs << "}\n"; | ||
|  |             } | ||
|  |         } | ||
|  | 
 | ||
|  |         destFileOs << "}\n"; | ||
|  |         destFileOs << gWinogradDestTail; | ||
|  |     } | ||
|  | 
 | ||
|  |     return 0; | ||
|  | } |