mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			282 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			282 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			C++
		
	
	
	
| //
 | |
| //  winogradGenerateCL.cpp
 | |
| //  MNN
 | |
| //
 | |
| //  Created by MNN on 2019/01/22.
 | |
| //  Copyright © 2018, Alibaba Group Holding Limited
 | |
| //
 | |
| 
 | |
| #include <math.h>
 | |
| #include <string.h>
 | |
| #include <fstream>
 | |
| #include <sstream>
 | |
| #include <MNN/MNNDefine.h>
 | |
| #include "math/Matrix.hpp"
 | |
| #include "math/WingoradGenerater.hpp"
 | |
| 
 | |
| using namespace std;
 | |
| 
 | |
// Fixed prologue of the generated "winogradTransformSource" OpenCL kernel.
// It holds the kernel signature and the index arithmetic shared by every
// unit/kernel-size combination. The text deliberately ends with two scopes
// still open (the kernel body and the bounds-check `if`); main() appends the
// generated transform after it and gWinogradSourceTail closes both scopes.
// Every line of the kernel text starts at column 0 inside the raw literal.
const char* gWinogradSourceHead = R"CLC(#include <common.h>
__kernel void winogradTransformSource
(
__read_only image2d_t uInput, //0
__write_only image2d_t uOutput,
__private const int unitWidth,
__private const int unitHeight,//3
__private const int padX,
__private const int padY,
__private const int srcWidth,//6
__private const int srcHeight,
__private const int srcChannelC4,
__private const int offsetX,//9
__private const int offsetY
)
{
int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2));
if (pos.x < unitWidth && pos.y < unitHeight)
{
int2 realPos = (int2)(pos.x+offsetX, pos.y+offsetY);
int dstXOrigin = pos.z;
int batchIndex = pos.z / srcChannelC4;
int srcZ = pos.z % srcChannelC4;
int dstYOrigin = unitWidth * pos.y + pos.x;
int dstHeight = (unitWidth*unitHeight+3)/4;
int dstY = dstYOrigin / 4;
int dstX = dstYOrigin % 4 + 4*dstXOrigin;
)CLC";
 | |
| 
 | |
// Epilogue of the generated source-transform kernel: closes the bounds-check
// `if` scope and the kernel body that gWinogradSourceHead leaves open.
const char* gWinogradSourceTail = "    }\n}\n";
 | |
| 
 | |
// Fixed prologue of the generated "winogradTransformDest" OpenCL kernel.
// It holds the kernel signature, the index arithmetic and the bias load shared
// by every unit/kernel-size combination. The text ends with two scopes still
// open (the kernel body and the bounds-check `if`); main() appends the
// generated transform after it and gWinogradDestTail closes both scopes.
// Every line of the kernel text starts at column 0 inside the raw literal.
const char* gWinogradDestHead = R"CLC(#include <common.h>
__kernel void winogradTransformDest
(
__read_only image2d_t uInput, //0
__read_only image2d_t uBias,
__write_only image2d_t uOutput,
__private const int unitWidth,//3
__private const int unitHeight,
__private const int dstWidth,
__private const int dstHeight,//6
__private const int dstChannelC4,
__private const int offsetX,
__private const int offsetY
)
{
int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2));
if (pos.x < unitWidth && pos.y < unitHeight)
{
int2 realPos = (int2)(pos.x+offsetX, pos.y+offsetY);
int srcWidth = (unitWidth*unitHeight+3)/4;
int dstXOrigin = unitWidth * pos.y + pos.x;
int dstX = dstXOrigin / 4;
int dstY = 4*pos.z + dstXOrigin % 4;
int oz = pos.z % dstChannelC4;
DATA_TYPE4 bias = READ_IMAGE(uBias, SAMPLER, (int2)(oz, 0));
int batchIndex = pos.z / dstChannelC4;
)CLC";
 | |
| 
 | |
// Epilogue of the generated dest-transform kernel: closes the bounds-check
// `if` scope and the kernel body that gWinogradDestHead leaves open.
const char* gWinogradDestTail = "    }\n}\n";
 | |
// Writes `v` to `output` as a floating-point literal carrying the "h" suffix
// used by the generated kernels (presumably an OpenCL half literal - confirm
// against the kernel build options).
//
// The value is formatted with the destination stream's flags and precision.
// ".0" is appended only when the formatted text is a bare integer such as
// "2" or "-4". The previous check (`v == floor(v)`) also fired when the
// stream had already produced exponent notation, turning 1e7f into the
// malformed token "1e+07.0h".
//
// Non-finite values are not given any special treatment.
inline void _printFloat(std::ostream& output, float v) {
    std::ostringstream text;
    text.flags(output.flags());
    text.precision(output.precision());
    text << v;
    const std::string digits = text.str();

    output << digits;
    // Only sign characters and decimal digits => no '.' or exponent was printed.
    if (digits.find_first_not_of("+-0123456789") == std::string::npos) {
        output << ".0";
    }
    output << "h";
}
 | |
| 
 | |
| int main(int argc, const char* argv[]) {
 | |
|     int unit       = atoi(argv[1]);
 | |
|     int kernelSize = atoi(argv[2]);
 | |
|     auto alpha     = unit + kernelSize - 1;
 | |
|     float interp   = 0.5f;
 | |
|     if (argc > 3) {
 | |
|         interp = atof(argv[3]);
 | |
|     }
 | |
|     MNN::Math::WinogradGenerater generater(unit, kernelSize, interp);
 | |
|     auto a = generater.A();
 | |
|     auto b = generater.B();
 | |
|     auto g = generater.G();
 | |
| 
 | |
|     MNN::Math::Matrix::print(a.get(), "A");
 | |
|     MNN::Math::Matrix::print(b.get(), "B");
 | |
|     MNN::Math::Matrix::print(g.get(), "G");
 | |
|     std::ostringstream sourceFileOstream;
 | |
|     { sourceFileOstream << "winogradTransformSource" << unit << "_" << kernelSize << "_" << interp << ".cl"; }
 | |
|     auto sourceFile = sourceFileOstream.str();
 | |
|     MNN_PRINT("Generate %s\n", sourceFile.c_str());
 | |
|     {
 | |
|         std::ofstream sourceOutput(sourceFile.c_str());
 | |
|         sourceOutput << gWinogradSourceHead << "{\n";
 | |
|         sourceOutput << "int sxStart = (realPos.x)*" << unit << " - padX;\n";
 | |
|         sourceOutput << "int syStart = (realPos.y)*" << unit << " - padY;\n";
 | |
| 
 | |
|         // Load
 | |
|         for (int y = 0; y < alpha; ++y) {
 | |
|             for (int x = 0; x < alpha; ++x) {
 | |
|                 sourceOutput << "DATA_TYPE4 S" << x << y << ";\n";
 | |
|             }
 | |
|         }
 | |
|         for (int y = 0; y < alpha; ++y) {
 | |
|             for (int x = 0; x < alpha; ++x) {
 | |
|                 sourceOutput << "{\n";
 | |
|                 sourceOutput << "int sx = " << x << "+sxStart;\n";
 | |
|                 sourceOutput << "int sy = " << y << "+syStart;\n";
 | |
|                 sourceOutput << "int imageSx = select(sx + srcZ*srcWidth, -1, sx<0 || sx >= srcWidth);\n";
 | |
|                 sourceOutput << "int imageSy = select(batchIndex*srcHeight + sy, -1, sy<0 || sy >= srcHeight);\n";
 | |
| 
 | |
|                 sourceOutput << "S" << x << y << "= READ_IMAGE(uInput, SAMPLER, (int2)(imageSx, imageSy));\n";
 | |
|                 sourceOutput << "}\n";
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         // M = BT*S
 | |
|         auto bFloat = b->host<float>();
 | |
|         for (int y = 0; y < alpha; ++y) {
 | |
|             for (int x = 0; x < alpha; ++x) {
 | |
|                 sourceOutput << "DATA_TYPE4 m" << x << y << "= ";
 | |
| 
 | |
|                 for (int k = 0; k < alpha; ++k) {
 | |
|                     auto value = bFloat[alpha * k + y];
 | |
|                     if (0.0f == value) {
 | |
|                         continue;
 | |
|                     } else if (1.0f == value) {
 | |
|                         sourceOutput << "+S" << x << k;
 | |
|                     } else if (-1.0f == value) {
 | |
|                         sourceOutput << "-S" << x << k;
 | |
|                     } else {
 | |
|                         if (value > 0) {
 | |
|                             sourceOutput << "+";
 | |
|                         }
 | |
|                         _printFloat(sourceOutput, value);
 | |
|                         sourceOutput << "*S" << x << k;
 | |
|                     }
 | |
|                 }
 | |
|                 sourceOutput << ";\n";
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         // S = M*B
 | |
|         for (int y = 0; y < alpha; ++y) {
 | |
|             for (int x = 0; x < alpha; ++x) {
 | |
|                 sourceOutput << "WRITE_IMAGE(uOutput, (int2)(dstX, dstY+dstHeight*" << (y * alpha + x) << "), ";
 | |
|                 for (int k = 0; k < alpha; ++k) {
 | |
|                     auto value = bFloat[alpha * k + x];
 | |
|                     if (0.0f == value) {
 | |
|                         continue;
 | |
|                     } else if (1.0f == value) {
 | |
|                         sourceOutput << "+m" << k << y;
 | |
|                     } else if (-1.0f == value) {
 | |
|                         sourceOutput << "-m" << k << y;
 | |
|                     } else {
 | |
|                         if (value > 0) {
 | |
|                             sourceOutput << "+";
 | |
|                         }
 | |
|                         _printFloat(sourceOutput, value);
 | |
| 
 | |
|                         sourceOutput << "*m" << k << y;
 | |
|                     }
 | |
|                 }
 | |
|                 sourceOutput << ");\n";
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         sourceOutput << "}\n";
 | |
|         sourceOutput << gWinogradSourceTail;
 | |
|     }
 | |
| 
 | |
|     std::ostringstream destFileOstream;
 | |
|     { destFileOstream << "winogradTransformDest" << unit << "_" << kernelSize << "_" << interp << ".cl"; }
 | |
|     auto destFile = destFileOstream.str();
 | |
|     MNN_PRINT("Generate %s\n", destFile.c_str());
 | |
|     {
 | |
|         std::ofstream destFileOs(destFile.c_str());
 | |
|         destFileOs << gWinogradDestHead;
 | |
|         destFileOs << "{\n";
 | |
|         auto aFloat = a->host<float>();
 | |
|         destFileOs << " int oyStart = realPos.y * " << unit << ";\n";
 | |
|         destFileOs << " int oxStart = realPos.x * " << unit << ";\n";
 | |
|         // Load
 | |
|         for (int y = 0; y < alpha; ++y) {
 | |
|             for (int x = 0; x < alpha; ++x) {
 | |
|                 destFileOs << "DATA_TYPE4 S" << x << y << "= READ_IMAGE(uInput, SAMPLER, (int2)(dstX+srcWidth*"
 | |
|                            << (x + y * alpha) << ", dstY));\n";
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         // M = AT* S
 | |
|         for (int y = 0; y < unit; ++y) {
 | |
|             for (int x = 0; x < alpha; ++x) {
 | |
|                 destFileOs << "DATA_TYPE4 m" << x << y << "= ";
 | |
|                 for (int k = 0; k < alpha; ++k) {
 | |
|                     auto value = aFloat[k * unit + y];
 | |
|                     if (0.0f == value) {
 | |
|                         continue;
 | |
|                     } else if (1.0f == value) {
 | |
|                         destFileOs << "+S" << x << k;
 | |
|                     } else if (-1.0f == value) {
 | |
|                         destFileOs << "-S" << x << k;
 | |
|                     } else {
 | |
|                         if (value > 0) {
 | |
|                             destFileOs << "+";
 | |
|                         }
 | |
|                         _printFloat(destFileOs, value);
 | |
|                         destFileOs << "*S" << x << k;
 | |
|                     }
 | |
|                 }
 | |
|                 destFileOs << ";\n";
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         // S = M * A
 | |
|         for (int y = 0; y < unit; ++y) {
 | |
|             for (int x = 0; x < unit; ++x) {
 | |
|                 destFileOs << "{\n";
 | |
|                 destFileOs << "int ox = oxStart + " << x << ";\n";
 | |
|                 destFileOs << "int oy = oyStart + " << y << ";\n";
 | |
|                 destFileOs << "if (ox < dstWidth && oy < dstHeight)\n{\n";
 | |
|                 destFileOs << "int imageOx = ox+oz*dstWidth;\n";
 | |
|                 destFileOs << "int imageOy = oy+batchIndex*dstHeight;\n";
 | |
|                 destFileOs << "DATA_TYPE4 res = bias";
 | |
|                 for (int k = 0; k < alpha; ++k) {
 | |
|                     auto value = aFloat[k * unit + x];
 | |
|                     if (0.0f == value) {
 | |
|                         continue;
 | |
|                     } else if (1.0f == value) {
 | |
|                         destFileOs << "+m" << k << y;
 | |
|                     } else if (-1.0f == value) {
 | |
|                         destFileOs << "-m" << k << y;
 | |
|                     } else {
 | |
|                         if (value > 0) {
 | |
|                             destFileOs << "+";
 | |
|                         }
 | |
|                         _printFloat(destFileOs, value);
 | |
|                         destFileOs << "*m" << k << y;
 | |
|                     }
 | |
|                 }
 | |
|                 destFileOs << ";\n";
 | |
| 
 | |
|                 destFileOs << "#ifdef RELU\n";
 | |
|                 destFileOs << "res = max(res, (DATA_TYPE4)(0));\n";
 | |
|                 destFileOs << "#endif\n";
 | |
|                 destFileOs << "#ifdef RELU6\n";
 | |
|                 destFileOs << "res = clamp(res, (DATA_TYPE4)(0), (DATA_TYPE4)(6));\n";
 | |
|                 destFileOs << "#endif\n";
 | |
|                 destFileOs << "WRITE_IMAGE(uOutput, (int2)(imageOx, imageOy), res);\n";
 | |
|                 destFileOs << "}\n";
 | |
| 
 | |
|                 destFileOs << "}\n";
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         destFileOs << "}\n";
 | |
|         destFileOs << gWinogradDestTail;
 | |
|     }
 | |
| 
 | |
|     return 0;
 | |
| }
 |