mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			252 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			252 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			C++
		
	
	
	
| //
 | |
| //  winogradGenerateGLSL.cpp
 | |
| //  MNN
 | |
| //
 | |
| //  Created by MNN on 2019/01/22.
 | |
| //  Copyright © 2018, Alibaba Group Holding Limited
 | |
| //
 | |
| 
 | |
#include <string.h>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <limits>
#include <ostream>
#include <sstream>
#include <string>
#include <MNN/MNNDefine.h>
#include "math/Matrix.hpp"
#include "math/WingoradGenerater.hpp"
 | |
| 
 | |
| using namespace std;
 | |
| 
 | |
// GLSL prologue of the generated input-transform compute shader.
// main() writes this text first, then appends the per-tile statements.
// The prologue opens main() and the bounds-checked block, and declares the
// locals the generated statements rely on: dstX, dstY, srcHeight, sxStart
// and syStart.
// The raw literal reproduces the shader text byte-for-byte, including its
// uneven indentation and the trailing newline before the closing delimiter.
const char* gWinogradSourceHead = R"GLSL(#version 450 core
layout(std430) buffer;
layout(std430) uniform;
layout(set=0, binding=0, rgba16f) writeonly restrict uniform image2D uOutput;
layout(set=0, binding=1) uniform sampler3D uInput;
layout(set=0, binding=2) readonly restrict uniform constBuffer {
    ivec4 inputSize;
    ivec4 outputSize;
    int padX;
    int padY;
    int unitWidth;
    int unitHeight;
    int unit;
} uConst;
layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;
void main()
{
    ivec3 pos = ivec3(gl_GlobalInvocationID);
    if (pos.x < uConst.unitWidth && pos.y < uConst.unitHeight)
    {
int dstXOrigin = pos.z;
int dstYOrigin = uConst.unitWidth * pos.y + pos.x;
int srcHeight = (uConst.unitWidth*uConst.unitHeight+3)/4;
int dstY = dstYOrigin / 4;
int dstX = dstYOrigin % 4 + 4*dstXOrigin;
        int sxStart = pos.x*uConst.unit - uConst.padX;
        int syStart = pos.y*uConst.unit - uConst.padY;
)GLSL";
 | |
| 
 | |
// GLSL epilogue of the input-transform shader. It closes the bounds-checked
// block and main() that gWinogradSourceHead opened.
const char* gWinogradSourceTail = R"GLSL(    }
}
)GLSL";
 | |
| 
 | |
// GLSL prologue of the generated output-transform compute shader.
// main() writes this text first, then appends the per-tile statements.
// The prologue opens main() and the bounds-checked block. It declares the
// locals the generated statements use: dstWidth, dstX, dstY, bias, oxStart
// and oyStart (plus oz).
// The raw literal reproduces the shader text byte-for-byte, including its
// uneven indentation and the trailing newline before the closing delimiter.
const char* gWinogradDestHead = R"GLSL(#version 450 core
layout(std430) buffer;
layout(std430) uniform;
layout(set=0, binding=0, rgba16f) writeonly restrict uniform image3D uOutput;
layout(set=0, binding=1) uniform sampler2D uInput;
layout(set=0, binding=2) uniform sampler2D uBias;
layout(set=0, binding=3) readonly restrict uniform constBuffer {
    ivec4 inputSize;
    ivec4 outputSize;
    int padX;
    int padY;
    int unitWidth;
    int unitHeight;
    int unit;
} uConst;
layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;
void main()
{
    ivec3 pos = ivec3(gl_GlobalInvocationID);
    if (pos.x < uConst.unitWidth && pos.y < uConst.unitHeight)
    {
int dstWidth = (uConst.unitWidth*uConst.unitHeight+3)/4;
int dstXOrigin = uConst.unitWidth * pos.y + pos.x;
int dstX = dstXOrigin / 4;
int dstY = 4*pos.z + dstXOrigin % 4;
        vec4 bias = texelFetch(uBias, ivec2(pos.z, 0), 0);
int oyStart = pos.y * uConst.unit;
int oxStart = pos.x * uConst.unit;
int oz = pos.z;
)GLSL";
 | |
| 
 | |
// GLSL epilogue of the output-transform shader. It closes the bounds-checked
// block and main() that gWinogradDestHead opened.
const char* gWinogradDestTail = R"GLSL(    }
}
)GLSL";
 | |
| 
 | |
| int main(int argc, const char* argv[]) {
 | |
|     int unit       = atoi(argv[1]);
 | |
|     int kernelSize = atoi(argv[2]);
 | |
|     auto alpha     = unit + kernelSize - 1;
 | |
|     float interp   = 0.5f;
 | |
|     if (argc > 3) {
 | |
|         interp = atof(argv[3]);
 | |
|     }
 | |
|     MNN::Math::WinogradGenerater generater(unit, kernelSize, interp);
 | |
|     auto a = generater.A();
 | |
|     auto b = generater.B();
 | |
|     auto g = generater.G();
 | |
| 
 | |
|     MNN::Math::Matrix::print(a.get(), "A");
 | |
|     MNN::Math::Matrix::print(b.get(), "B");
 | |
|     MNN::Math::Matrix::print(g.get(), "G");
 | |
|     std::ostringstream sourceFileOstream;
 | |
|     { sourceFileOstream << "winogradTransformSource" << unit << "_" << kernelSize << "_" << interp << ".comp"; }
 | |
|     auto sourceFile = sourceFileOstream.str();
 | |
|     MNN_PRINT("Generate %s\n", sourceFile.c_str());
 | |
|     {
 | |
|         std::ofstream sourceOutput(sourceFile.c_str());
 | |
|         sourceOutput << gWinogradSourceHead << "{\n";
 | |
| 
 | |
|         // Load
 | |
|         for (int y = 0; y < alpha; ++y) {
 | |
|             for (int x = 0; x < alpha; ++x) {
 | |
|                 sourceOutput << "vec4 S" << x << y << "= texelFetch(uInput, ivec3(sxStart+" << x << ", syStart+ " << y
 | |
|                              << ", pos.z), 0);\n";
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         // M = BT*S
 | |
|         auto bFloat = b->host<float>();
 | |
|         for (int y = 0; y < alpha; ++y) {
 | |
|             for (int x = 0; x < alpha; ++x) {
 | |
|                 sourceOutput << "vec4 m" << x << y << "= ";
 | |
| 
 | |
|                 for (int k = 0; k < alpha; ++k) {
 | |
|                     auto value = bFloat[alpha * k + y];
 | |
|                     if (0.0f == value) {
 | |
|                         continue;
 | |
|                     } else if (1.0f == value) {
 | |
|                         sourceOutput << "+S" << x << k;
 | |
|                     } else if (-1.0f == value) {
 | |
|                         sourceOutput << "-S" << x << k;
 | |
|                     } else {
 | |
|                         if (value > 0) {
 | |
|                             sourceOutput << "+";
 | |
|                         }
 | |
|                         sourceOutput << value << "*S" << x << k;
 | |
|                     }
 | |
|                 }
 | |
|                 sourceOutput << ";\n";
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         // S = M*B
 | |
|         for (int y = 0; y < alpha; ++y) {
 | |
|             for (int x = 0; x < alpha; ++x) {
 | |
|                 sourceOutput << "imageStore(uOutput, ivec2(dstX, dstY+srcHeight*" << (y * alpha + x) << "), ";
 | |
|                 for (int k = 0; k < alpha; ++k) {
 | |
|                     auto value = bFloat[alpha * k + x];
 | |
|                     if (0.0f == value) {
 | |
|                         continue;
 | |
|                     } else if (1.0f == value) {
 | |
|                         sourceOutput << "+m" << k << y;
 | |
|                     } else if (-1.0f == value) {
 | |
|                         sourceOutput << "-m" << k << y;
 | |
|                     } else {
 | |
|                         if (value > 0) {
 | |
|                             sourceOutput << "+";
 | |
|                         }
 | |
|                         sourceOutput << value << "*m" << k << y;
 | |
|                     }
 | |
|                 }
 | |
|                 sourceOutput << ");\n";
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         sourceOutput << "}\n";
 | |
|         sourceOutput << gWinogradSourceTail;
 | |
|     }
 | |
| 
 | |
|     std::ostringstream destFileOstream;
 | |
|     { destFileOstream << "winogradTransformDest" << unit << "_" << kernelSize << "_" << interp << ".comp"; }
 | |
|     auto destFile = destFileOstream.str();
 | |
|     MNN_PRINT("Generate %s\n", destFile.c_str());
 | |
|     {
 | |
|         std::ofstream destFileOs(destFile.c_str());
 | |
|         destFileOs << gWinogradDestHead;
 | |
|         destFileOs << "{\n";
 | |
|         auto aFloat = a->host<float>();
 | |
| 
 | |
|         // Load
 | |
|         for (int y = 0; y < alpha; ++y) {
 | |
|             for (int x = 0; x < alpha; ++x) {
 | |
|                 destFileOs << "vec4 S" << x << y << "= texelFetch(uInput, ivec2(dstX+dstWidth*" << (x + y * alpha)
 | |
|                            << ", dstY), 0);\n";
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         // M = AT* S
 | |
|         for (int y = 0; y < unit; ++y) {
 | |
|             for (int x = 0; x < alpha; ++x) {
 | |
|                 destFileOs << "vec4 m" << x << y << "= ";
 | |
|                 for (int k = 0; k < alpha; ++k) {
 | |
|                     auto value = aFloat[k * unit + y];
 | |
|                     if (0.0f == value) {
 | |
|                         continue;
 | |
|                     } else if (1.0f == value) {
 | |
|                         destFileOs << "+S" << x << k;
 | |
|                     } else if (-1.0f == value) {
 | |
|                         destFileOs << "-S" << x << k;
 | |
|                     } else {
 | |
|                         if (value > 0) {
 | |
|                             destFileOs << "+";
 | |
|                         }
 | |
|                         destFileOs << value << "*S" << x << k;
 | |
|                     }
 | |
|                 }
 | |
|                 destFileOs << ";\n";
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         // S = M * A
 | |
|         for (int y = 0; y < unit; ++y) {
 | |
|             for (int x = 0; x < unit; ++x) {
 | |
|                 destFileOs << "{\n";
 | |
|                 destFileOs << "vec4 res = bias";
 | |
|                 for (int k = 0; k < alpha; ++k) {
 | |
|                     auto value = aFloat[k * unit + x];
 | |
|                     if (0.0f == value) {
 | |
|                         continue;
 | |
|                     } else if (1.0f == value) {
 | |
|                         destFileOs << "+m" << k << y;
 | |
|                     } else if (-1.0f == value) {
 | |
|                         destFileOs << "-m" << k << y;
 | |
|                     } else {
 | |
|                         if (value > 0) {
 | |
|                             destFileOs << "+";
 | |
|                         }
 | |
|                         destFileOs << value << "*m" << k << y;
 | |
|                     }
 | |
|                 }
 | |
|                 destFileOs << ";\n";
 | |
| 
 | |
|                 destFileOs << "#ifdef RELU\n";
 | |
|                 destFileOs << "res = max(res, vec4(0));\n";
 | |
|                 destFileOs << "#endif\n";
 | |
|                 destFileOs << "#ifdef RELU6\n";
 | |
|                 destFileOs << "res = clamp(res, vec4(0), vec4(6));\n";
 | |
|                 destFileOs << "#endif\n";
 | |
|                 destFileOs << "imageStore(uOutput, ivec3(oxStart+" << x << ", oyStart+" << y << ", pos.z), res);\n";
 | |
| 
 | |
|                 destFileOs << "}\n";
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         destFileOs << "}\n";
 | |
|         destFileOs << gWinogradDestTail;
 | |
|     }
 | |
| 
 | |
|     return 0;
 | |
| }
 |