//
//  SelectGrad.cpp
//  MNN
//
//  Created by MNN on 2019/05/26.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "SelectGrad.hpp"
#include "core/Macro.h"
using namespace std;
using namespace MNN;
using namespace MNN::Express;

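// Gradient of the Select op: select(cond, a, b) returns a where cond is true and b elsewhere.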
class SelectGrad : public OpGrad {
public:
    SelectGrad() {
        mType = SEMI_LINEAR;
    }
    virtual std::vector<Express::VARP> onGrad(Express::EXPRP expr,
                                              const std::vector<Express::VARP>& backwardOutput) override {
        auto inputs = expr->inputs();
        std::vector<VARP> result(inputs.size(), nullptr);
        auto outputDiff = backwardOutput[0];
        // d (select(x, a, b)) = da * (x>0) + db * (x < 0)
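        // No gradient is computed for inputs[0] (the boolean condition); result[0] stays nullptr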
        {
            // Cast x>0 -> float
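            // The cast turns the boolean condition into a float mask: 1.0 where true, 0.0 where false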
            unique_ptr<OpT> mask(new OpT);
            mask->type                     = OpType_Cast;
            mask->main.type                = OpParameter_CastParam;
            mask->main.value               = new CastParamT;
            mask->main.AsCastParam()->dstT = DataType_DT_FLOAT;
            mask->main.AsCastParam()->srcT = DataType_DT_BOOL;

            auto maskVar = Variable::create(Expr::create(std::move(mask), {inputs[0]}));

            // da * (x>0)
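            // Gradient for the second input: pass outputDiff through where the condition is true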
            result[1] = _Multiply(outputDiff, maskVar);

            // db * -((x>0)-1)
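            // -(mask - 1) == 1 - mask, so the gradient reaches the third input only where the condition is false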
            auto one  = _Const(1.0f);
            auto sub  = _Subtract(maskVar, one);
            auto neg  = _Negative(sub);
            result[2] = _Multiply(outputDiff, neg);
        }

        return result;
    }
};

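// Register the Select gradient at static-initialization time.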
static const auto gRegister = []() {
    static SelectGrad _c;
    OpGrad::insert(OpType_Select, &_c);
    return true;
}();