//
// MetalConvolutionCommon.hpp
// MNN
//
// Created by MNN on 2019/02/25.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef MetalConvolutionCommon_hpp
#define MetalConvolutionCommon_hpp
#import "core/ConvolutionCommon.hpp"
#import "MetalBackend.hpp"
#import "MetalExecution.hpp"
#import "MNNMetalContext.h"
#if MNN_METAL_ENABLED
namespace MNN {

/**
 * Common base class for Metal convolution executions.
 *
 * Derives from MetalExecution and declares fields for the convolution
 * configuration (depthwise flag, group count, kernel size, pad mode and
 * padding, stride, dilation, activation type) together with the Metal
 * buffers for weights, bias and constants. Concrete subclasses must
 * implement onFloat().
 *
 * NOTE(review): the fields are presumably filled from the op passed to the
 * constructor and by loadWeight(); that implementation is not visible in
 * this header, so confirm against the .mm file.
 */
class MetalConvolutionCommon : public MetalExecution {
public:
    /// @param backend Backend this execution belongs to.
    /// @param op      Convolution op being executed (presumably stored in mOp — confirm).
    MetalConvolutionCommon(Backend *backend, const MNN::Op *op);
    virtual ~MetalConvolutionCommon() = default;

    /// MetalExecution override, invoked with the current input/output tensors.
    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

    /// MetalExecution override that records this convolution's work on `encoder`.
    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;

protected:
    /// Loads weight data described by `conv`.
    /// NOTE(review): presumably populates mWeight / mBias — confirm in the implementation.
    void loadWeight(const MNN::Convolution2D *conv);

    /// Subclass hook: encode the float convolution from `input` to `output` on `encoder`.
    /// Pure virtual, so every concrete subclass must provide it.
    virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) = 0;

    /// Builds a Metal buffer from float weights `src` with the given group,
    /// output-channel, input-channel and kernel dimensions.
    /// Virtual so subclasses can supply their own weight layout.
    virtual id<MTLBuffer> weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src);

private:
    /// Builds the weight buffer for a Convolution2D.
    /// NOTE(review): the Int8Common argument presumably carries quantized-weight
    /// data and the meaning of the bool flag is not visible here — confirm.
    id<MTLBuffer> weightForConv(const Convolution2D *, ConvolutionCommon::Int8Common *, bool);

protected:
    bool mDepthwise = false;
    int mGroups = 0;
    int mKernelX = 0;
    int mKernelY = 0;
    PadMode mPadMode = PadMode_CAFFE;
    int mPadX = 0;
    int mPadY = 0;
    int mStrideX = 0;
    int mStrideY = 0;
    int mDilateX = 0;
    int mDilateY = 0;
    int mActivationType = 0;
    const MNN::Op *mOp = nullptr;          // Non-owning pointer to the op (lifetime managed elsewhere — assumption).
    id<MTLBuffer> mWeight = nil;           // Weight buffer.
    id<MTLBuffer> mBias = nil;             // Bias buffer.
    id<MTLBuffer> mConstBuffer = nil;      // NOTE(review): presumably kernel parameters for the shader — confirm.
};

} // namespace MNN
#endif /* MNN_METAL_ENABLED */
#endif /* MetalConvolutionCommon_hpp */