mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			241 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			241 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			C++
		
	
	
	
| /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License.
 | |
| ==============================================================================*/
 | |
| #ifdef MNN_SUPPORT_DEPRECATED_OP
 | |
| 
 | |
| #include "backend/cpu/compute/OptimizedComputer.hpp"
 | |
| #include <string.h>
 | |
| #include "core/Macro.h"
 | |
| #ifdef MNN_USE_NEON
 | |
| #include <arm_neon.h>
 | |
| #endif
 | |
| 
 | |
| namespace MNN {
 | |
| namespace Optimized {
 | |
| 
 | |
| // avgpooling
 | |
| void AveragePool(const uint8_t* input_data, const std::vector<int>& input_dims, int stride_width, int stride_height,
 | |
|                  int pad_width, int pad_height, int filter_width, int filter_height, int mOutputActivationMin,
 | |
|                  int mOutputActivationMax, uint8_t* output_data, const std::vector<int>& output_dims) {
 | |
|     MNN_ASSERT(mOutputActivationMin < mOutputActivationMax);
 | |
|     MNN_ASSERT(input_dims.at(0) == output_dims.at(0));
 | |
|     MNN_ASSERT(input_dims.at(3) == output_dims.at(3));
 | |
|     const int inputBatches  = input_dims.at(0);
 | |
|     const int inputChannels = input_dims.at(3);
 | |
|     const int inputHeight   = input_dims.at(1);
 | |
|     const int inputWidth    = input_dims.at(2);
 | |
|     const int outputHeight  = output_dims.at(1);
 | |
|     const int outputWidth   = output_dims.at(2);
 | |
| 
 | |
| #define UNIT 4
 | |
|     const int inputChannelUnits = UP_DIV(inputChannels, UNIT);
 | |
|     const int inputChannelRound = ROUND_UP(inputChannels, UNIT);
 | |
| 
 | |
|     for (int batch = 0; batch < inputBatches; ++batch) {
 | |
|         for (int out_y = 0; out_y < outputHeight; ++out_y) {
 | |
|             for (int out_x = 0; out_x < outputWidth; ++out_x) {
 | |
|                 const int in_x_origin    = (out_x * stride_width) - pad_width;
 | |
|                 const int in_y_origin    = (out_y * stride_height) - pad_height;
 | |
|                 const int filter_x_start = std::max(0, -in_x_origin);
 | |
|                 const int filter_x_end   = std::min(filter_width, inputWidth - in_x_origin);
 | |
|                 const int filter_y_start = std::max(0, -in_y_origin);
 | |
|                 const int filter_y_end   = std::min(filter_height, inputHeight - in_y_origin);
 | |
|                 const int filter_count   = (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
 | |
|                 uint8_t* output_ptr      = output_data + batch * outputHeight * outputWidth * inputChannelRound +
 | |
|                                       out_y * outputWidth * UNIT + out_x * UNIT;
 | |
| #ifdef MNN_USE_NEON
 | |
|                 uint16_t result_sub = filter_count / 2;
 | |
|                 uint16x4_t min_vec  = vdup_n_u16(mOutputActivationMin);
 | |
|                 uint16x4_t max_vec  = vdup_n_u16(mOutputActivationMax);
 | |
|                 uint16x8_t acc_reg;
 | |
|                 uint16_t acc[UNIT * 2];
 | |
|                 const uint8_t* input_ptr = input_data + batch * inputHeight * inputWidth * inputChannelRound +
 | |
|                                            in_y_origin * inputWidth * UNIT + in_x_origin * UNIT;
 | |
| 
 | |
|                 for (int channel = 0; channel < inputChannelUnits; channel++) {
 | |
|                     memset(acc, 0, UNIT * 2 * sizeof(acc[0]));
 | |
|                     for (int fy = filter_y_start; fy < filter_y_end; fy++) {
 | |
|                         int fx  = filter_x_start;
 | |
|                         acc_reg = vld1q_u16(acc);
 | |
|                         for (; fx < filter_x_end - 2; fx += 2) {
 | |
|                             const uint8_t* input_cur_ptr = input_ptr + channel * inputHeight * inputWidth * UNIT +
 | |
|                                                            fy * inputWidth * UNIT + fx * UNIT;
 | |
|                             uint8x8_t input_reg = vld1_u8(input_cur_ptr);
 | |
|                             acc_reg             = vaddw_u8(acc_reg, input_reg);
 | |
|                         }
 | |
|                         vst1_u16(acc, vadd_u16(vget_high_u16(acc_reg), vget_low_u16(acc_reg)));
 | |
|                         for (; fx < filter_x_end; fx++) {
 | |
|                             const uint8_t* input_cur_ptr = input_ptr + channel * inputHeight * inputWidth * UNIT +
 | |
|                                                            fy * inputWidth * UNIT + fx * UNIT;
 | |
|                             for (int c = 0; c < UNIT; c++) {
 | |
|                                 acc[c] += input_cur_ptr[c];
 | |
|                             }
 | |
|                         }
 | |
|                     }
 | |
|                     uint8_t* output_cur_ptr = output_ptr + channel * outputHeight * outputWidth * UNIT;
 | |
|                     uint16x4_t a            = vdup_n_u16(0);
 | |
|                     for (int c = 0; c < UNIT; c++) {
 | |
|                         a[c] = (acc[c] + result_sub) / filter_count;
 | |
|                     }
 | |
|                     a                 = vmin_u16(a, max_vec);
 | |
|                     a                 = vmax_u16(a, min_vec);
 | |
|                     output_cur_ptr[0] = static_cast<uint8_t>(a[0]);
 | |
|                     output_cur_ptr[1] = static_cast<uint8_t>(a[1]);
 | |
|                     output_cur_ptr[2] = static_cast<uint8_t>(a[2]);
 | |
|                     output_cur_ptr[3] = static_cast<uint8_t>(a[3]);
 | |
|                 }
 | |
| #else
 | |
|                 uint16_t acc[UNIT];
 | |
|                 const uint8_t* input_ptr = input_data + batch * inputHeight * inputWidth * inputChannelRound +
 | |
|                                            in_y_origin * inputWidth * UNIT + in_x_origin * UNIT;
 | |
| 
 | |
|                 for (int channel = 0; channel < inputChannelUnits; channel++) {
 | |
|                     memset(acc, 0, UNIT * sizeof(acc[0]));
 | |
|                     for (int fy = filter_y_start; fy < filter_y_end; fy++) {
 | |
|                         for (int fx = filter_x_start; fx < filter_x_end; fx++) {
 | |
|                             const uint8_t* input_cur_ptr = input_ptr + channel * inputHeight * inputWidth * UNIT +
 | |
|                                                            fy * inputWidth * UNIT + fx * UNIT;
 | |
|                             for (int c = 0; c < UNIT; c++) {
 | |
|                                 acc[c] += input_cur_ptr[c];
 | |
|                             }
 | |
|                         }
 | |
|                     }
 | |
|                     for (int c = 0; c < UNIT; c++) {
 | |
|                         uint16_t a = (acc[c] + filter_count / 2) / filter_count;
 | |
|                         a          = std::max<uint16_t>(a, mOutputActivationMin);
 | |
|                         a          = std::min<uint16_t>(a, mOutputActivationMax);
 | |
|                         output_ptr[channel * outputHeight * outputWidth * UNIT + c] = static_cast<uint8_t>(a);
 | |
|                     }
 | |
|                 }
 | |
| #endif
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| void Logistic(const uint8_t* input_data, const std::vector<int>& input_dims, int32_t inputZeroPoint,
 | |
|               int32_t input_range_radius, int32_t input_multiplier, int input_left_shift, uint8_t* output_data,
 | |
|               const std::vector<int>& output_dims) {
 | |
|     int size = 1;
 | |
|     for (int i = 0; i < input_dims.size(); i++) {
 | |
|         size *= input_dims.at(i);
 | |
|     }
 | |
| 
 | |
|     int c = 0;
 | |
| 
 | |
| #ifdef MNN_USE_NEON
 | |
|     // Handle 16 values at a time
 | |
|     for (; c <= size - 16; c += 16) {
 | |
|         // Read input uint8 values, cast to int16 and subtract inputZeroPoint
 | |
|         uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
 | |
|         int16x8_t input_val_centered_0 =
 | |
|             vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))), vdupq_n_s16(inputZeroPoint));
 | |
|         int16x8_t input_val_centered_1 =
 | |
|             vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))), vdupq_n_s16(inputZeroPoint));
 | |
| 
 | |
|         // Prepare the bit masks that we will use at the end to implement the logic
 | |
|         // that was expressed in the scalar code with branching:
 | |
|         //   if (input_val_centered < -input_range_radius) {
 | |
|         //     output_val = 0;
 | |
|         //   } else if (input_val_centered > input_range_radius) {
 | |
|         //     output_val = 255;
 | |
|         //   } else {
 | |
|         //     ...
 | |
|         uint16x8_t mask_rightclamp_0 = vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
 | |
|         uint16x8_t mask_rightclamp_1 = vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
 | |
|         uint16x8_t mask_leftclamp_0  = vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
 | |
|         uint16x8_t mask_leftclamp_1  = vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
 | |
|         uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8), vshrn_n_u16(mask_rightclamp_1, 8));
 | |
|         uint8x16_t mask_leftclamp  = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8), vshrn_n_u16(mask_leftclamp_1, 8));
 | |
| 
 | |
|         // This performs what is expressed in the scalar code as
 | |
|         // const int32 input_val_rescaled =
 | |
|         //     MultiplyByQuantizedMultiplierGreaterThanOne(
 | |
|         //         input_val_centered, input_multiplier, input_left_shift);
 | |
|         int32x4_t input_val_rescaled_0 =
 | |
|             vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)), vdupq_n_s32(input_left_shift));
 | |
|         int32x4_t input_val_rescaled_1 =
 | |
|             vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)), vdupq_n_s32(input_left_shift));
 | |
|         int32x4_t input_val_rescaled_2 =
 | |
|             vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)), vdupq_n_s32(input_left_shift));
 | |
|         int32x4_t input_val_rescaled_3 =
 | |
|             vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)), vdupq_n_s32(input_left_shift));
 | |
|         input_val_rescaled_0 = vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
 | |
|         input_val_rescaled_1 = vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
 | |
|         input_val_rescaled_2 = vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
 | |
|         input_val_rescaled_3 = vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
 | |
| 
 | |
|         // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t
 | |
|         using FixedPoint4                 = FixedPoint<int32x4_t, 4>;
 | |
|         using FixedPoint0                 = FixedPoint<int32x4_t, 0>;
 | |
|         const FixedPoint4 input_val_f4_0  = FixedPoint4::FromRaw(input_val_rescaled_0);
 | |
|         const FixedPoint4 input_val_f4_1  = FixedPoint4::FromRaw(input_val_rescaled_1);
 | |
|         const FixedPoint4 input_val_f4_2  = FixedPoint4::FromRaw(input_val_rescaled_2);
 | |
|         const FixedPoint4 input_val_f4_3  = FixedPoint4::FromRaw(input_val_rescaled_3);
 | |
|         const FixedPoint0 output_val_f0_0 = logistic(input_val_f4_0);
 | |
|         const FixedPoint0 output_val_f0_1 = logistic(input_val_f4_1);
 | |
|         const FixedPoint0 output_val_f0_2 = logistic(input_val_f4_2);
 | |
|         const FixedPoint0 output_val_f0_3 = logistic(input_val_f4_3);
 | |
| 
 | |
|         // Divide by 2^23 as in the scalar code
 | |
|         int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23);
 | |
|         int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23);
 | |
|         int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23);
 | |
|         int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23);
 | |
| 
 | |
|         // Cast output values to uint8, saturating
 | |
|         int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0), vqmovn_s32(output_val_s32_1));
 | |
|         int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2), vqmovn_s32(output_val_s32_3));
 | |
|         uint8x16_t output_val_u8   = vcombine_u8(vqmovun_s16(output_val_s16_0), vqmovun_s16(output_val_s16_1));
 | |
| 
 | |
|         // Perform the bit-masking with the bit masks computed at the beginning,
 | |
|         // see the comment there.
 | |
|         output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
 | |
|         output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
 | |
| 
 | |
|         // Store back to memory
 | |
|         vst1q_u8(output_data + c, output_val_u8);
 | |
|     }
 | |
| #endif
 | |
|     // Leftover loop: handle one value at a time with scalar code.
 | |
|     for (; c < size; ++c) {
 | |
|         const uint8_t input_val_u8       = input_data[c];
 | |
|         const int32_t input_val_centered = static_cast<int32_t>(input_val_u8) - inputZeroPoint;
 | |
|         uint8_t output_val;
 | |
|         if (input_val_centered < -input_range_radius) {
 | |
|             output_val = 0;
 | |
|         } else if (input_val_centered > input_range_radius) {
 | |
|             output_val = 255;
 | |
|         } else {
 | |
|             const int32_t input_val_rescaled =
 | |
|                 MultiplyByQuantizedMultiplierGreaterThanOne(input_val_centered, input_multiplier, input_left_shift);
 | |
|             const FixedPoint<int32_t, 4> input_val_f4  = FixedPoint<int32_t, 4>::FromRaw(input_val_rescaled);
 | |
|             const FixedPoint<int32_t, 0> output_val_f0 = logistic(input_val_f4);
 | |
|             int32_t output_val_s32                     = RoundingDivideByPOT(output_val_f0.raw(), 23);
 | |
|             if (output_val_s32 == 256) {
 | |
|                 output_val_s32 = 255;
 | |
|             }
 | |
|             MNN_ASSERT(output_val_s32 >= 0);
 | |
|             MNN_ASSERT(output_val_s32 <= 255);
 | |
|             output_val = static_cast<uint8_t>(output_val_s32);
 | |
|         }
 | |
|         output_data[c] = output_val;
 | |
|     }
 | |
| }
 | |
| 
 | |
| } // namespace Optimized
 | |
| } // namespace MNN
 | |
| 
 | |
| #endif
 |