MNN/source/backend/cuda/execution/ConvSingleInputExecution.cu

56 lines
1.9 KiB
Plaintext
Raw Normal View History

2020-11-05 16:41:56 +08:00
//
// ConvSingleInputExecution.cu
// MNN
//
// Created by MNN on 2020/08/22.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "ConvSingleInputExecution.hpp"
2022-08-12 10:30:48 +08:00
#include "ConvWinogradExecution.hpp"
#include "ConvCutlassExecution.hpp"
2022-09-30 10:02:52 +08:00
#include "backend/cuda/core/CUDATools.hpp"
2020-11-05 16:41:56 +08:00
namespace MNN {
namespace CUDA {

/**
 * Creator registered for OpType_Convolution on the CUDA backend.
 *
 * Selection policy implemented by onCreate():
 *  - Returns nullptr when the op carries no Convolution2D payload.
 *  - Returns nullptr for IDST-int8 quantized weights: quanParameter type 1 or 2
 *    together with integer scales (scaleInt).
 *  - With USE_MNN_CONV defined, ConvSingleInputExecution is always used.
 *  - Otherwise ConvWinogradExecution is used when ConvWinogradExecution::isValid()
 *    accepts the convolution parameters, and ConvCutlassExecution is used in all
 *    other cases.
 */
class CUDAConvolutionCreator : public CUDABackend::Creator {
public:
    /**
     * Build a convolution Execution for `op` on `backend`.
     *
     * @param inputs   Input tensors (not inspected; see note on Winograd selection).
     * @param outputs  Output tensors (not inspected).
     * @param op       The convolution op; its Convolution2D payload drives the choice.
     * @param backend  Backend that owns the created execution and its resources.
     * @return A newly allocated Execution, or nullptr when this creator cannot
     *         handle the op.
     */
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        // Query the flatbuffer union once. It yields nullptr if the op's main
        // payload is not a Convolution2D, which would otherwise be dereferenced.
        auto conv2D = op->main_as_Convolution2D();
        if (nullptr == conv2D) {
            return nullptr;
        }

        auto quan = conv2D->quanParameter();
        if (nullptr != quan && (1 == quan->type() || 2 == quan->type()) && quan->has_scaleInt()) {
            // IDST-int8 weights (integer scales) are rejected here; the original
            // author notes they produce errors on this path. Returning nullptr
            // presumably lets the runtime fall back to another backend — confirm.
            return nullptr;
        }

#ifdef USE_MNN_CONV
        std::shared_ptr<ConvSingleInputExecution::Resource> resource(new ConvSingleInputExecution::Resource(backend, op));
        return new ConvSingleInputExecution(backend, op, resource);
#else
        // Winograd eligibility is decided from the op parameters alone, because
        // inputs[0] is not usable at creation time (per the original note).
        if (ConvWinogradExecution::isValid(conv2D)) {
            std::shared_ptr<ConvWinogradExecution::Resource> resource(new ConvWinogradExecution::Resource(backend, op));
            return new ConvWinogradExecution(backend, op, resource);
        }
        // General fallback: CUTLASS-based convolution.
        std::shared_ptr<ConvCutlassExecution::Resource> resource(new ConvCutlassExecution::Resource(backend, op));
        return new ConvCutlassExecution(backend, op, resource);
#endif
    }
};

// Static registration: associates OpType_Convolution with the creator above.
CUDACreatorRegister<CUDAConvolutionCreator> __ConvExecution(OpType_Convolution);

} // namespace CUDA
} // namespace MNN