2019-04-17 10:49:11 +08:00
|
|
|
//
|
2021-03-12 18:41:50 +08:00
|
|
|
// ConvBufWinograd.hpp
|
2019-04-17 10:49:11 +08:00
|
|
|
// MNN
|
|
|
|
//
|
|
|
|
// Created by MNN on 2019/02/01.
|
|
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
|
|
//
|
|
|
|
|
2021-03-12 18:41:50 +08:00
|
|
|
#ifndef MNN_OPENCL_BUFFER_CLOSED
|
|
|
|
|
|
|
|
#ifndef __CONVBUF_WINOGRAD__
|
|
|
|
#define __CONVBUF_WINOGRAD__
|
2019-04-17 10:49:11 +08:00
|
|
|
|
2019-12-27 22:16:57 +08:00
|
|
|
#include "core/Execution.hpp"
|
2019-04-17 10:49:11 +08:00
|
|
|
|
|
|
|
#include <array>
|
|
|
|
#include <memory>
|
|
|
|
#include <vector>
|
2021-03-12 18:41:50 +08:00
|
|
|
#include "backend/opencl/execution/buffer/ConvBufExecution.hpp"
|
|
|
|
#include "backend/opencl/core/OpenCLRunningUtils.hpp"
|
|
|
|
|
2019-04-17 10:49:11 +08:00
|
|
|
namespace MNN {
|
|
|
|
namespace OpenCL {
|
2021-03-12 18:41:50 +08:00
|
|
|
class ConvBufWinograd : public Execution {
|
2019-04-17 10:49:11 +08:00
|
|
|
public:
|
2021-03-12 18:41:50 +08:00
|
|
|
ConvBufWinograd(const MNN::Convolution2D* op, Backend* backend);
|
|
|
|
virtual ~ConvBufWinograd();
|
2019-04-17 10:49:11 +08:00
|
|
|
|
|
|
|
virtual ErrorCode onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) override;
|
|
|
|
virtual ErrorCode onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) override;
|
2023-03-20 11:32:29 +08:00
|
|
|
static bool valid(const Convolution2DCommon* common, const Tensor* input, const Tensor* output, int limit = 8192);
|
2020-07-06 17:48:55 +08:00
|
|
|
std::vector<uint32_t> getLocalWS(std::string kernelName, int index, std::vector<uint32_t> &gws, const uint32_t maxWorkGroupSize, cl::Kernel mKernel);
|
2019-04-17 10:49:11 +08:00
|
|
|
|
|
|
|
private:
|
|
|
|
OpenCLBackend* mOpenCLBackend;
|
|
|
|
const Convolution2DCommon* mCommon;
|
|
|
|
int mKernelX;
|
|
|
|
int mKernelY;
|
|
|
|
int mStrideX;
|
|
|
|
int mStrideY;
|
2021-03-12 18:41:50 +08:00
|
|
|
std::shared_ptr<Tensor> mWeight;
|
|
|
|
std::shared_ptr<Tensor> mBias;
|
2019-04-17 10:49:11 +08:00
|
|
|
|
|
|
|
std::shared_ptr<Tensor> mSource;
|
|
|
|
std::shared_ptr<Tensor> mDest;
|
|
|
|
|
2020-06-23 14:02:09 +08:00
|
|
|
std::vector<cl::Kernel> mSourceTransform;
|
|
|
|
std::vector<cl::Kernel> mDestTransform;
|
|
|
|
std::vector<cl::Kernel> mMatMul;
|
2019-05-14 19:54:21 +08:00
|
|
|
|
2020-06-23 14:02:09 +08:00
|
|
|
std::vector<uint32_t> mMaxWGS_S;
|
|
|
|
std::vector<uint32_t> mMaxWGS_D;
|
|
|
|
std::vector<uint32_t> mMaxWGS_M;
|
|
|
|
|
|
|
|
std::vector<std::vector<uint32_t> > mGWS_S;
|
|
|
|
std::vector<std::vector<uint32_t> > mGWS_D;
|
|
|
|
std::vector<std::vector<uint32_t> > mGWS_M;
|
|
|
|
|
|
|
|
std::vector<std::vector<uint32_t> > mLWS_S;
|
|
|
|
std::vector<std::vector<uint32_t> > mLWS_D;
|
|
|
|
std::vector<std::vector<uint32_t> > mLWS_M;
|
2019-04-17 10:49:11 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace OpenCL
|
|
|
|
} // namespace MNN
|
|
|
|
|
2021-03-12 18:41:50 +08:00
|
|
|
#endif /* __CONVBUF_WINOGRAD__ */
|
|
|
|
#endif /* MNN_OPENCL_BUFFER_CLOSED */
|