MNN/source/backend/cpu/compute/Int8FunctionsOpt.cpp

//
// Int8FunctionsOpt.cpp
// MNN
//
// Created by MNN on 2018/08/15.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "Int8FunctionsOpt.h"
#include <math.h>
#include <algorithm>
#include "core/Macro.h"
#ifndef MNN_USE_NEON
#ifndef MNN_USE_SSE
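
// Requantize one int32 accumulator to int8: add the int32 bias, apply the
// per-channel float scale, then round and saturate to the symmetric range
// [-127, 127].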
inline int8_t int32ToInt8(int data, int bias, float scale) {
    float value = (float)(data + bias) * scale;
    value       = std::max(value, -127.0f);
    value       = std::min(value, 127.0f);
    return static_cast<int8_t>(roundf(value));
}
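
// Reference (scalar) int8 GEMM micro-kernel, used when neither NEON nor SSE
// is available. For each of GEMM_INT8_DST_XUNIT output positions it
// accumulates GEMM_INT8_UNIT output channels over src_depth_quad blocks of
// GEMM_INT8_SRC_UNIT input channels, then adds bias and requantizes with a
// per-channel scale. The "16x4" in the name presumably reflects the tile
// shape (GEMM_INT8_SRC_UNIT = 16, GEMM_INT8_UNIT = 4) defined in
// Int8FunctionsOpt.h.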
void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, const int32_t* bias,
                                       const float* scale, size_t src_depth_quad, size_t dst_step,
                                       size_t dst_depth_quad) {
    const auto dst_step_tmp = dst_step / sizeof(int8_t);
    for (int dz = 0; dz < dst_depth_quad; ++dz) {
        const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
        const auto bias_dz   = bias + dz * GEMM_INT8_UNIT;
        const auto scale_dz  = scale + dz * GEMM_INT8_UNIT;
        auto dst_z           = dst + dz * dst_step_tmp;
        for (int w = 0; w < GEMM_INT8_DST_XUNIT; ++w) {
            const auto src_x = src + w * GEMM_INT8_SRC_UNIT;
            auto dst_x       = dst_z + w * GEMM_INT8_UNIT;
            int32_t dstTemp[GEMM_INT8_UNIT] = {0};

            // Accumulate src (int8) x weight (int8) products in int32 to avoid overflow.
            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sz;
                const auto src_z     = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT;
                for (int j = 0; j < GEMM_INT8_UNIT; ++j) {
                    const auto weight_j = weight_sz + j * GEMM_INT8_SRC_UNIT;
                    for (int i = 0; i < GEMM_INT8_SRC_UNIT; ++i) {
                        dstTemp[j] += (int32_t)src_z[i] * (int32_t)weight_j[i];
                    }
                }
            }
            // Requantize each of the GEMM_INT8_UNIT accumulators back to int8.
            for (int j = 0; j < GEMM_INT8_UNIT; ++j) {
                dst_x[j] = int32ToInt8(dstTemp[j], bias_dz[j], scale_dz[j]);
            }
        }
    }
}
#endif
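
// Quantize floats to int8, four channels at a time: sizeQuad counts groups of
// four values, scalep holds the four per-channel scales, and each scaled
// value is rounded to nearest and clamped to [minValue, maxValue].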
void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue,
                   ssize_t maxValue) {
    for (int i = 0; i < sizeQuad; ++i) {
        for (int j = 0; j < 4; ++j) {
            int v = (int)roundf(src[4 * i + j] * scalep[j]);
            if (v > maxValue) {
                v = maxValue;
            }
            if (v < minValue) {
                v = minValue;
            }
            dst[4 * i + j] = (int8_t)v;
        }
    }
}
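
// Reference depthwise convolution for a single output pixel across
// gDepthwiseUnit (4) channels: accumulate the fw x fh kernel window in float
// from int8 inputs and weights, then apply the per-channel dequantization
// scale. The *_step arguments are element strides provided by the caller.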
static const int gDepthwiseUnit = 4;
void MNNConvRunForUnitDepthWiseInt8(float* dst, const int8_t* src, const int8_t* weight, size_t fw, size_t fh,
                                    size_t weight_y_step, size_t dilateX_step, size_t dilateY_step,
                                    const float* scale) {
    int fx, fy;
    for (int i = 0; i < gDepthwiseUnit; ++i) {
        dst[i] = 0;
    }
    auto src_z    = src;
    auto weight_z = weight;
    for (fy = 0; fy < fh; ++fy) {
        auto src_y    = src_z + fy * dilateY_step;
        auto weight_y = weight_z + fy * weight_y_step;
        for (fx = 0; fx < fw; ++fx) {
            auto weight_x = weight_y + gDepthwiseUnit * fx;
            auto src_x    = src_y + fx * dilateX_step;
            for (int j = 0; j < gDepthwiseUnit; ++j) {
                dst[j] += (float)src_x[j] * (float)weight_x[j];
            }
        }
    }
    for (int i = 0; i < gDepthwiseUnit; ++i) {
        dst[i] = dst[i] * scale[i];
    }
}
#endif
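
// Usage sketch for MNNFloat2Int8 (illustrative values only, not part of the
// original source): quantize two groups of four floats with per-channel
// scale 2.0 and clamp to [-127, 127].
//
//     float src[8] = {0.4f, -1.3f, 2.0f, 63.49f, -100.0f, 0.0f, 5.5f, -5.5f};
//     int8_t dst[8];
//     float scalep[4] = {2.0f, 2.0f, 2.0f, 2.0f};
//     MNNFloat2Int8(src, dst, 2, scalep, -127, 127);
//     // dst = {1, -3, 4, 127, -127, 0, 11, -11}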