[PATCH 07/28] add fp16 conv and workaround strided slice

This commit is contained in:
如幻 2020-03-23 09:32:02 +08:00 committed by xiaying
parent f097ad2b5e
commit ad793bc6ab
2 changed files with 45 additions and 5 deletions

View File

@ -182,7 +182,45 @@ VARP _Conv(std::vector<float>&& weight, std::vector<float>&& bias, VARP x, INTS
conv2D->bias = std::move(bias);
return (Variable::create(Expr::create(convOp.get(), {x})));
}
VARP _Conv(std::vector<__fp16>weight, std::vector<float>&& bias, VARP x, INTS channel, INTS kernelSize,
PaddingMode pad, INTS stride, INTS dilate, int group, INTS pads, bool relu, bool relu6) {
std::unique_ptr<OpT> convOp(new OpT);
convOp->type = OpType_Convolution;
if (channel[0] == channel[1] && channel[0] == group) {
convOp->type = OpType_ConvolutionDepthwise;
}
convOp->main.type = OpParameter_Convolution2D;
convOp->main.value = new Convolution2DT;
auto conv2D = convOp->main.AsConvolution2D();
conv2D->common.reset(new Convolution2DCommonT);
conv2D->common->padMode = _convertPadMode(pad);
if (pads.size() == 2) {
conv2D->common->padX = pads[0];
conv2D->common->padY = pads[1];
} else {
conv2D->common->pads = std::move(pads);
}
conv2D->common->strideX = stride[0];
conv2D->common->strideY = stride[1];
conv2D->common->group = group;
conv2D->common->outputCount = channel[1];
conv2D->common->inputCount = channel[0];
conv2D->common->dilateX = dilate[0];
conv2D->common->dilateY = dilate[1];
conv2D->common->kernelX = kernelSize[0];
conv2D->common->kernelY = kernelSize[1];
conv2D->common->relu6 = relu6;
conv2D->common->relu = relu;
MNN_ASSERT(weight.size() == channel[1] * (channel[0] / group) * kernelSize[0] * kernelSize[1]);
conv2D->quanParameter.reset(new IDSTQuanT);
conv2D->quanParameter->type = 3;
int8_t* halfweight = reinterpret_cast<int8_t*>(weight.data());
conv2D->quanParameter->buffer.assign(halfweight, halfweight + weight.size() * sizeof(__fp16));
conv2D->weight.clear();
MNN_ASSERT(bias.size() == channel[1]);
conv2D->bias = std::move(bias);
return (Variable::create(Expr::create(convOp.get(), {x})));
}
VARP _Conv(float weight, float bias, VARP x, INTS channel, INTS kernelSize, PaddingMode pad, INTS stride, INTS dilate,
int group) {
std::unique_ptr<OpT> convOp(new OpT);
@ -471,20 +509,20 @@ VARP _Slice(VARP x, VARP starts, VARP sizes) {
return (Variable::create(Expr::create(slice.get(), {x, starts, sizes})));
}
VARP _StridedSlice(VARP x, VARP begin, VARP end, VARP strided, halide_type_t type, int32_t beginMask,
VARP _StridedSlice(VARP input, VARP begin, VARP end, VARP strided, int32_t beginMask,
int32_t endMask, int32_t ellipsisMask, int32_t newAxisMask, int32_t shrinkAxisMask) {
std::unique_ptr<OpT> op(new OpT);
op->type = OpType_StridedSlice;
op->main.type = OpParameter_StridedSliceParam;
op->main.value = new StridedSliceParamT;
op->main.AsStridedSliceParam()->T = (MNN::DataType)Utils::convertDataType(type);;
op->main.AsStridedSliceParam()->T = DataType_DT_FLOAT;
op->main.AsStridedSliceParam()->beginMask = beginMask;
op->main.AsStridedSliceParam()->endMask = endMask;
op->main.AsStridedSliceParam()->ellipsisMask = ellipsisMask;
op->main.AsStridedSliceParam()->newAxisMask = newAxisMask;
op->main.AsStridedSliceParam()->shrinkAxisMask = shrinkAxisMask;
return (Variable::create(Expr::create(op.get(), {x, begin, end, strided})));
return (Variable::create(Expr::create(op.get(), {input, begin, end, strided})));
}
/*Transposes x.
Args:

View File

@ -33,6 +33,8 @@ MNN_PUBLIC VARP _Conv(VARP weight, VARP bias, VARP x, PaddingMode pad = VALID, I
MNN_PUBLIC VARP _Conv(float weight, float bias, VARP x, INTS channel, INTS kernelSize, PaddingMode pad = VALID,
INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1);
MNN_PUBLIC VARP _Conv(std::vector<__fp16> weight, std::vector<float>&& bias, VARP x, INTS channel, INTS kernelSize,
PaddingMode pad = VALID, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}, bool relu = false, bool relu6 = false);
MNN_PUBLIC VARP _Conv(std::vector<float>&& weight, std::vector<float>&& bias, VARP x, INTS channel, INTS kernelSize,
PaddingMode pad = VALID, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}, bool relu = false, bool relu6 = false);
MNN_PUBLIC VARP _Deconv(VARP weight, VARP bias, VARP x, PaddingMode pad = VALID, INTS stride = {1, 1},
@ -51,7 +53,7 @@ MNN_PUBLIC VARP _Softplus(VARP features);
MNN_PUBLIC VARP _Softsign(VARP features);
MNN_PUBLIC std::vector<VARP> _Split(VARP value, INTS size_splits, int axis = 0);
MNN_PUBLIC VARP _Slice(VARP x, VARP starts, VARP sizes);
MNN_PUBLIC VARP _StridedSlice(VARP x, VARP begin, VARP end, VARP strided, halide_type_t type,
MNN_PUBLIC VARP _StridedSlice(VARP input, VARP begin, VARP end, VARP strided,
int32_t beginMask, int32_t endMask, int32_t ellipsisMask,
int32_t newAxisMask, int32_t shrinkAxisMask);
MNN_PUBLIC VARP _Concat(VARPS values, int axis);