/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef MNN_SUPPORT_DEPRECATED_OP

#include "backend/cpu/compute/OptimizedComputer.hpp"
#include <string.h>
#include "core/Macro.h"
#ifdef MNN_USE_NEON
#include <arm_neon.h>
#endif

namespace MNN {
namespace Optimized {

// avgpooling
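// Quantized (uint8) average pooling. Dims are ordered {batch, height, width,
// channels}, and the data is laid out with channels packed in groups of
// UNIT = 4 (MNN's C4 layout). Padded positions outside the image are
// excluded from the average rather than counted as zeros.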
void AveragePool(const uint8_t* input_data, const std::vector<int>& input_dims, int stride_width, int stride_height,
                 int pad_width, int pad_height, int filter_width, int filter_height, int mOutputActivationMin,
                 int mOutputActivationMax, uint8_t* output_data, const std::vector<int>& output_dims) {
    MNN_ASSERT(mOutputActivationMin < mOutputActivationMax);
    MNN_ASSERT(input_dims.at(0) == output_dims.at(0));
    MNN_ASSERT(input_dims.at(3) == output_dims.at(3));
    const int inputBatches  = input_dims.at(0);
    const int inputChannels = input_dims.at(3);
    const int inputHeight   = input_dims.at(1);
    const int inputWidth    = input_dims.at(2);
    const int outputHeight  = output_dims.at(1);
    const int outputWidth   = output_dims.at(2);

#define UNIT 4
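    // inputChannelUnits is the number of 4-channel groups to process;
    // inputChannelRound is the channel count rounded up to a multiple of
    // UNIT, i.e. the packed channel extent used in the per-batch stride.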
    const int inputChannelUnits = UP_DIV(inputChannels, UNIT);
    const int inputChannelRound = ROUND_UP(inputChannels, UNIT);

    for (int batch = 0; batch < inputBatches; ++batch) {
        for (int out_y = 0; out_y < outputHeight; ++out_y) {
            for (int out_x = 0; out_x < outputWidth; ++out_x) {
                const int in_x_origin    = (out_x * stride_width) - pad_width;
                const int in_y_origin    = (out_y * stride_height) - pad_height;
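                // Clamp the pooling window to the input bounds so that
                // padded positions outside the image do not enter the sum.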
                const int filter_x_start = std::max(0, -in_x_origin);
                const int filter_x_end   = std::min(filter_width, inputWidth - in_x_origin);
                const int filter_y_start = std::max(0, -in_y_origin);
                const int filter_y_end   = std::min(filter_height, inputHeight - in_y_origin);
                const int filter_count   = (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
                uint8_t* output_ptr      = output_data + batch * outputHeight * outputWidth * inputChannelRound +
                                      out_y * outputWidth * UNIT + out_x * UNIT;
#ifdef MNN_USE_NEON
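                // NEON path: widening-accumulate two C4 pixels (8 bytes) per
                // load into acc_reg. Adding result_sub = filter_count / 2
                // before the division rounds the average to nearest.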
                uint16_t result_sub = filter_count / 2;
                uint16x4_t min_vec  = vdup_n_u16(mOutputActivationMin);
                uint16x4_t max_vec  = vdup_n_u16(mOutputActivationMax);
                uint16x8_t acc_reg;
                uint16_t acc[UNIT * 2];
                const uint8_t* input_ptr = input_data + batch * inputHeight * inputWidth * inputChannelRound +
                                           in_y_origin * inputWidth * UNIT + in_x_origin * UNIT;

                for (int channel = 0; channel < inputChannelUnits; channel++) {
                    memset(acc, 0, UNIT * 2 * sizeof(acc[0]));
                    for (int fy = filter_y_start; fy < filter_y_end; fy++) {
                        int fx  = filter_x_start;
                        acc_reg = vld1q_u16(acc);
                        for (; fx < filter_x_end - 2; fx += 2) {
                            const uint8_t* input_cur_ptr = input_ptr + channel * inputHeight * inputWidth * UNIT +
                                                           fy * inputWidth * UNIT + fx * UNIT;
                            uint8x8_t input_reg = vld1_u8(input_cur_ptr);
                            acc_reg             = vaddw_u8(acc_reg, input_reg);
                        }
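                        // Fold the two per-pixel accumulators (high and low
                        // halves of acc_reg) back into acc; a tail of one or
                        // two pixels falls through to the scalar loop below.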
                        vst1_u16(acc, vadd_u16(vget_high_u16(acc_reg), vget_low_u16(acc_reg)));
                        for (; fx < filter_x_end; fx++) {
                            const uint8_t* input_cur_ptr = input_ptr + channel * inputHeight * inputWidth * UNIT +
                                                           fy * inputWidth * UNIT + fx * UNIT;
                            for (int c = 0; c < UNIT; c++) {
                                acc[c] += input_cur_ptr[c];
                            }
                        }
                    }
                    uint8_t* output_cur_ptr = output_ptr + channel * outputHeight * outputWidth * UNIT;
                    uint16x4_t a            = vdup_n_u16(0);
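                    // Per-lane rounded average, then clamp to the activation
                    // range.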
                    for (int c = 0; c < UNIT; c++) {
                        a[c] = (acc[c] + result_sub) / filter_count;
                    }
                    a                 = vmin_u16(a, max_vec);
                    a                 = vmax_u16(a, min_vec);
                    output_cur_ptr[0] = static_cast<uint8_t>(a[0]);
                    output_cur_ptr[1] = static_cast<uint8_t>(a[1]);
                    output_cur_ptr[2] = static_cast<uint8_t>(a[2]);
                    output_cur_ptr[3] = static_cast<uint8_t>(a[3]);
                }
#else
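                // Scalar fallback: same accumulate / round / clamp sequence
                // without NEON intrinsics.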
                uint16_t acc[UNIT];
                const uint8_t* input_ptr = input_data + batch * inputHeight * inputWidth * inputChannelRound +
                                           in_y_origin * inputWidth * UNIT + in_x_origin * UNIT;

                for (int channel = 0; channel < inputChannelUnits; channel++) {
                    memset(acc, 0, UNIT * sizeof(acc[0]));
                    for (int fy = filter_y_start; fy < filter_y_end; fy++) {
                        for (int fx = filter_x_start; fx < filter_x_end; fx++) {
                            const uint8_t* input_cur_ptr = input_ptr + channel * inputHeight * inputWidth * UNIT +
                                                           fy * inputWidth * UNIT + fx * UNIT;
                            for (int c = 0; c < UNIT; c++) {
                                acc[c] += input_cur_ptr[c];
                            }
                        }
                    }
                    for (int c = 0; c < UNIT; c++) {
                        uint16_t a = (acc[c] + filter_count / 2) / filter_count;
                        a          = std::max<uint16_t>(a, mOutputActivationMin);
                        a          = std::min<uint16_t>(a, mOutputActivationMax);
                        output_ptr[channel * outputHeight * outputWidth * UNIT + c] = static_cast<uint8_t>(a);
                    }
                }
#endif
            }
        }
    }
}
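
// Quantized logistic (sigmoid) on uint8 data, computed with gemmlowp-style
// fixed-point arithmetic. Centered inputs outside
// [-input_range_radius, input_range_radius] saturate to 0 or 255.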
void Logistic(const uint8_t* input_data, const std::vector<int>& input_dims, int32_t inputZeroPoint,
              int32_t input_range_radius, int32_t input_multiplier, int input_left_shift, uint8_t* output_data,
              const std::vector<int>& output_dims) {
    int size = 1;
    for (int i = 0; i < input_dims.size(); i++) {
        size *= input_dims.at(i);
    }

    int c = 0;

#ifdef MNN_USE_NEON
    // Handle 16 values at a time
    for (; c <= size - 16; c += 16) {
        // Read input uint8 values, cast to int16 and subtract inputZeroPoint
        uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
        int16x8_t input_val_centered_0 =
            vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))), vdupq_n_s16(inputZeroPoint));
        int16x8_t input_val_centered_1 =
            vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))), vdupq_n_s16(inputZeroPoint));

        // Prepare the bit masks that we will use at the end to implement the logic
        // that was expressed in the scalar code with branching:
        //   if (input_val_centered < -input_range_radius) {
        //     output_val = 0;
        //   } else if (input_val_centered > input_range_radius) {
        //     output_val = 255;
        //   } else {
        //     ...
        uint16x8_t mask_rightclamp_0 = vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
        uint16x8_t mask_rightclamp_1 = vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
        uint16x8_t mask_leftclamp_0  = vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
        uint16x8_t mask_leftclamp_1  = vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
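        // The 16-bit comparison masks are all-ones or all-zeros per lane;
        // shifting right by 8 and narrowing turns them into per-byte
        // 0xFF / 0x00 masks for the final select.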
        uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8), vshrn_n_u16(mask_rightclamp_1, 8));
        uint8x16_t mask_leftclamp  = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8), vshrn_n_u16(mask_leftclamp_1, 8));

        // This performs what is expressed in the scalar code as
        // const int32 input_val_rescaled =
        //     MultiplyByQuantizedMultiplierGreaterThanOne(
        //         input_val_centered, input_multiplier, input_left_shift);
        int32x4_t input_val_rescaled_0 =
            vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)), vdupq_n_s32(input_left_shift));
        int32x4_t input_val_rescaled_1 =
            vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)), vdupq_n_s32(input_left_shift));
        int32x4_t input_val_rescaled_2 =
            vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)), vdupq_n_s32(input_left_shift));
        int32x4_t input_val_rescaled_3 =
            vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)), vdupq_n_s32(input_left_shift));
        input_val_rescaled_0 = vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
        input_val_rescaled_1 = vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
        input_val_rescaled_2 = vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
        input_val_rescaled_3 = vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);

        // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t
        using FixedPoint4                 = FixedPoint<int32x4_t, 4>;
        using FixedPoint0                 = FixedPoint<int32x4_t, 0>;
        const FixedPoint4 input_val_f4_0  = FixedPoint4::FromRaw(input_val_rescaled_0);
        const FixedPoint4 input_val_f4_1  = FixedPoint4::FromRaw(input_val_rescaled_1);
        const FixedPoint4 input_val_f4_2  = FixedPoint4::FromRaw(input_val_rescaled_2);
        const FixedPoint4 input_val_f4_3  = FixedPoint4::FromRaw(input_val_rescaled_3);
        const FixedPoint0 output_val_f0_0 = logistic(input_val_f4_0);
        const FixedPoint0 output_val_f0_1 = logistic(input_val_f4_1);
        const FixedPoint0 output_val_f0_2 = logistic(input_val_f4_2);
        const FixedPoint0 output_val_f0_3 = logistic(input_val_f4_3);

        // Divide by 2^23 as in the scalar code
        int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23);
        int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23);
        int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23);
        int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23);

        // Cast output values to uint8, saturating
        int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0), vqmovn_s32(output_val_s32_1));
        int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2), vqmovn_s32(output_val_s32_3));
        uint8x16_t output_val_u8   = vcombine_u8(vqmovun_s16(output_val_s16_0), vqmovun_s16(output_val_s16_1));

        // Perform the bit-masking with the bit masks computed at the beginning,
        // see the comment there.
        output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
        output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);

        // Store back to memory
        vst1q_u8(output_data + c, output_val_u8);
    }
#endif
    // Leftover loop: handle one value at a time with scalar code.
    for (; c < size; ++c) {
        const uint8_t input_val_u8       = input_data[c];
        const int32_t input_val_centered = static_cast<int32_t>(input_val_u8) - inputZeroPoint;
        uint8_t output_val;
        if (input_val_centered < -input_range_radius) {
            output_val = 0;
        } else if (input_val_centered > input_range_radius) {
            output_val = 255;
        } else {
            const int32_t input_val_rescaled =
                MultiplyByQuantizedMultiplierGreaterThanOne(input_val_centered, input_multiplier, input_left_shift);
            const FixedPoint<int32_t, 4> input_val_f4  = FixedPoint<int32_t, 4>::FromRaw(input_val_rescaled);
            const FixedPoint<int32_t, 0> output_val_f0 = logistic(input_val_f4);
            int32_t output_val_s32                     = RoundingDivideByPOT(output_val_f0.raw(), 23);
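            // The fixed-point sigmoid can saturate to exactly 1.0, which
            // maps to 256 after the rounding shift; pull it back into range.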
            if (output_val_s32 == 256) {
                output_val_s32 = 255;
            }
            MNN_ASSERT(output_val_s32 >= 0);
            MNN_ASSERT(output_val_s32 <= 255);
            output_val = static_cast<uint8_t>(output_val_s32);
        }
        output_data[c] = output_val;
    }
}

} // namespace Optimized
} // namespace MNN

#endif