mirror of https://github.com/alibaba/MNN.git
				
				
				
			[PATCH 07/28] add fp16 conv and workaround strided slice
This commit is contained in:
		
							parent
							
								
									f097ad2b5e
								
							
						
					
					
						commit
						ad793bc6ab
					
				|  | @ -182,7 +182,45 @@ VARP _Conv(std::vector<float>&& weight, std::vector<float>&& bias, VARP x, INTS | |||
|     conv2D->bias = std::move(bias); | ||||
|     return (Variable::create(Expr::create(convOp.get(), {x}))); | ||||
| } | ||||
| 
 | ||||
| VARP _Conv(std::vector<__fp16>weight, std::vector<float>&& bias, VARP x, INTS channel, INTS kernelSize, | ||||
|            PaddingMode pad, INTS stride, INTS dilate, int group, INTS pads, bool relu, bool relu6) { | ||||
|     std::unique_ptr<OpT> convOp(new OpT); | ||||
|     convOp->type = OpType_Convolution; | ||||
|     if (channel[0] == channel[1] && channel[0] == group) { | ||||
|         convOp->type = OpType_ConvolutionDepthwise; | ||||
|     } | ||||
|     convOp->main.type  = OpParameter_Convolution2D; | ||||
|     convOp->main.value = new Convolution2DT; | ||||
|     auto conv2D        = convOp->main.AsConvolution2D(); | ||||
|     conv2D->common.reset(new Convolution2DCommonT); | ||||
|     conv2D->common->padMode     = _convertPadMode(pad); | ||||
|     if (pads.size() == 2) { | ||||
|         conv2D->common->padX        = pads[0]; | ||||
|         conv2D->common->padY        = pads[1]; | ||||
|     } else { | ||||
|         conv2D->common->pads = std::move(pads); | ||||
|     } | ||||
|     conv2D->common->strideX     = stride[0]; | ||||
|     conv2D->common->strideY     = stride[1]; | ||||
|     conv2D->common->group       = group; | ||||
|     conv2D->common->outputCount = channel[1]; | ||||
|     conv2D->common->inputCount  = channel[0]; | ||||
|     conv2D->common->dilateX     = dilate[0]; | ||||
|     conv2D->common->dilateY     = dilate[1]; | ||||
|     conv2D->common->kernelX     = kernelSize[0]; | ||||
|     conv2D->common->kernelY     = kernelSize[1]; | ||||
|     conv2D->common->relu6 = relu6; | ||||
|     conv2D->common->relu = relu; | ||||
|     MNN_ASSERT(weight.size() == channel[1] * (channel[0] / group) * kernelSize[0] * kernelSize[1]); | ||||
|     conv2D->quanParameter.reset(new IDSTQuanT); | ||||
|     conv2D->quanParameter->type = 3; | ||||
|     int8_t* halfweight = reinterpret_cast<int8_t*>(weight.data()); | ||||
|     conv2D->quanParameter->buffer.assign(halfweight, halfweight + weight.size() * sizeof(__fp16)); | ||||
|     conv2D->weight.clear(); | ||||
|     MNN_ASSERT(bias.size() == channel[1]); | ||||
|     conv2D->bias = std::move(bias); | ||||
|     return (Variable::create(Expr::create(convOp.get(), {x}))); | ||||
| } | ||||
| VARP _Conv(float weight, float bias, VARP x, INTS channel, INTS kernelSize, PaddingMode pad, INTS stride, INTS dilate, | ||||
|            int group) { | ||||
|     std::unique_ptr<OpT> convOp(new OpT); | ||||
|  | @ -471,20 +509,20 @@ VARP _Slice(VARP x, VARP starts, VARP sizes) { | |||
|     return (Variable::create(Expr::create(slice.get(), {x, starts, sizes}))); | ||||
| } | ||||
| 
 | ||||
| VARP _StridedSlice(VARP x, VARP begin, VARP end, VARP strided, halide_type_t type, int32_t beginMask, | ||||
| VARP _StridedSlice(VARP input, VARP begin, VARP end, VARP strided, int32_t beginMask, | ||||
|                    int32_t endMask, int32_t ellipsisMask, int32_t newAxisMask, int32_t shrinkAxisMask) { | ||||
|     std::unique_ptr<OpT> op(new OpT); | ||||
|     op->type                        = OpType_StridedSlice; | ||||
|     op->main.type                   = OpParameter_StridedSliceParam; | ||||
|     op->main.value                  = new StridedSliceParamT; | ||||
| 
 | ||||
|     op->main.AsStridedSliceParam()->T              = (MNN::DataType)Utils::convertDataType(type);; | ||||
|     op->main.AsStridedSliceParam()->T = DataType_DT_FLOAT; | ||||
|     op->main.AsStridedSliceParam()->beginMask      = beginMask; | ||||
|     op->main.AsStridedSliceParam()->endMask        = endMask; | ||||
|     op->main.AsStridedSliceParam()->ellipsisMask   = ellipsisMask; | ||||
|     op->main.AsStridedSliceParam()->newAxisMask    = newAxisMask; | ||||
|     op->main.AsStridedSliceParam()->shrinkAxisMask = shrinkAxisMask; | ||||
|     return (Variable::create(Expr::create(op.get(), {x, begin, end, strided}))); | ||||
|     return (Variable::create(Expr::create(op.get(), {input, begin, end, strided}))); | ||||
| } | ||||
| /*Transposes x.
 | ||||
| Args: | ||||
|  |  | |||
|  | @ -33,6 +33,8 @@ MNN_PUBLIC VARP _Conv(VARP weight, VARP bias, VARP x, PaddingMode pad = VALID, I | |||
| 
 | ||||
| MNN_PUBLIC VARP _Conv(float weight, float bias, VARP x, INTS channel, INTS kernelSize, PaddingMode pad = VALID, | ||||
|                       INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1); | ||||
| MNN_PUBLIC VARP _Conv(std::vector<__fp16> weight, std::vector<float>&& bias, VARP x, INTS channel, INTS kernelSize, | ||||
|                       PaddingMode pad = VALID, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}, bool relu = false, bool relu6 = false); | ||||
| MNN_PUBLIC VARP _Conv(std::vector<float>&& weight, std::vector<float>&& bias, VARP x, INTS channel, INTS kernelSize, | ||||
|                       PaddingMode pad = VALID, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}, bool relu = false, bool relu6 = false); | ||||
| MNN_PUBLIC VARP _Deconv(VARP weight, VARP bias, VARP x, PaddingMode pad = VALID, INTS stride = {1, 1}, | ||||
|  | @ -51,7 +53,7 @@ MNN_PUBLIC VARP _Softplus(VARP features); | |||
| MNN_PUBLIC VARP _Softsign(VARP features); | ||||
| MNN_PUBLIC std::vector<VARP> _Split(VARP value, INTS size_splits, int axis = 0); | ||||
| MNN_PUBLIC VARP _Slice(VARP x, VARP starts, VARP sizes); | ||||
| MNN_PUBLIC VARP _StridedSlice(VARP x, VARP begin, VARP end, VARP strided, halide_type_t type, | ||||
| MNN_PUBLIC VARP _StridedSlice(VARP input, VARP begin, VARP end, VARP strided, | ||||
|                                       int32_t beginMask, int32_t endMask, int32_t ellipsisMask, | ||||
|                                       int32_t newAxisMask, int32_t shrinkAxisMask); | ||||
| MNN_PUBLIC VARP _Concat(VARPS values, int axis); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue