mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			67 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			67 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			C++
		
	
	
	
//
 | 
						|
//  TRTReduce.cpp
 | 
						|
//  MNN
 | 
						|
//
 | 
						|
//  Created by MNN on 2019/09/11.
 | 
						|
//  Copyright © 2018, Alibaba Group Holding Limited
 | 
						|
//
 | 
						|
 | 
						|
#include "TRTReduce.hpp"
 | 
						|
#include <core/TensorUtils.hpp>
 | 
						|
#include "TRTBackend.hpp"
 | 
						|
 | 
						|
using namespace std;
 | 
						|
 | 
						|
namespace MNN {
 | 
						|
 | 
						|
TRTReduce::TRTReduce(Backend *b, const Op *op, const std::vector<Tensor *> &inputs,
                     const std::vector<Tensor *> &outputs)
    : MNN::TRTCommonExecution(b, op) {
    // Remember the rank of the data input; onEncode relies on it when
    // handling negative axes and checking axis bounds.
    const Tensor *dataInput = inputs.front();
    inputDim                = dataInput->dimensions();
}
 | 
						|
 | 
						|
std::vector<ITensor *> TRTReduce::onEncode(const std::vector<ITensor *> &xOp) {
 | 
						|
#ifdef TRT_LOG
 | 
						|
    printf("TRTReduce in\n");
 | 
						|
#endif
 | 
						|
 | 
						|
    ReduceOperation operation = ReduceOperation::kSUM;
 | 
						|
    switch (mOp->main_as_ReductionParam()->operation()) {
 | 
						|
        case ReductionType_MEAN:
 | 
						|
            operation = ReduceOperation::kAVG;
 | 
						|
            break;
 | 
						|
        case ReductionType_SUM:
 | 
						|
            operation = ReduceOperation::kSUM;
 | 
						|
            break;
 | 
						|
        case ReductionType_MINIMUM:
 | 
						|
            operation = ReduceOperation::kMIN;
 | 
						|
            break;
 | 
						|
        case ReductionType_MAXIMUM:
 | 
						|
            operation = ReduceOperation::kMAX;
 | 
						|
            break;
 | 
						|
        case ReductionType_PROD:
 | 
						|
            operation = ReduceOperation::kPROD;
 | 
						|
            break;
 | 
						|
        default:
 | 
						|
            MNN_ASSERT(false);
 | 
						|
            break;
 | 
						|
    }
 | 
						|
    uint32_t mAxis = mOp->main_as_ReductionParam()->dim()->data()[0];
 | 
						|
    if (mAxis < 0) {
 | 
						|
        mAxis += inputDim;
 | 
						|
    }
 | 
						|
    MNN_ASSERT(mAxis >= 0 && mAxis < inputDim);
 | 
						|
 | 
						|
    bool keepdims = mOp->main_as_ReductionParam()->keepDims();
 | 
						|
 | 
						|
    // printf("reduce type:%d axis:%d keepdim:%d\n", mOp->main_as_ReductionParam()->operation(), mAxis, keepdims);
 | 
						|
 | 
						|
    auto Reduce_layer = mTrtBackend->getNetwork()->addReduce(*(xOp[0]), operation, 1U << mAxis, keepdims);
 | 
						|
    auto output       = Reduce_layer->getOutput(0);
 | 
						|
    return {output};
 | 
						|
}
 | 
						|
 | 
						|
// Namespace-scope registrar object constructed with OpType_Reduction; its TRTCreatorRegister
// constructor is expected to register TRTReduce as the TensorRT creator for Reduction ops
// during static initialization.
// NOTE(review): identifiers beginning with "__" are reserved for the implementation in C++;
// consider renaming once it is confirmed nothing else refers to this global by name.
TRTCreatorRegister<TypedCreator<TRTReduce>> __Reduce_op(OpType_Reduction);
 | 
						|
 | 
						|
} // namespace MNN
 |