2019-04-17 10:49:11 +08:00
|
|
|
//
|
|
|
|
// Int8FunctionsOpt.cpp
|
|
|
|
// MNN
|
|
|
|
//
|
|
|
|
// Created by MNN on 2018/08/15.
|
|
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
|
|
//
|
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
#include <math.h>
|
2020-11-05 16:41:56 +08:00
|
|
|
#include <cstring> // for memset
|
2020-02-26 09:57:17 +08:00
|
|
|
#include "Int8FunctionsOpt.h"
|
2019-12-27 22:16:57 +08:00
|
|
|
#include "core/Macro.h"
|
2021-06-11 17:17:13 +08:00
|
|
|
#include "CommonOptFunction.h"
|
2019-04-17 10:49:11 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
#ifdef MNN_USE_NEON
|
|
|
|
#include <arm_neon.h>
|
2019-04-17 10:49:11 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
extern "C" {
|
|
|
|
void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad,
|
|
|
|
const QuanPostTreatParameters* post, size_t realCount);
|
|
|
|
void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad,
|
|
|
|
const QuanPostTreatParameters* post, size_t realCount);
|
|
|
|
void MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width,
|
|
|
|
size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step);
|
|
|
|
#if defined(__aarch64__) && defined(MNN_USE_ARMV82) // aarch32 sdot workaround
|
|
|
|
void MNNGemmInt8AddBiasScale_ARMV82_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad,
|
|
|
|
const QuanPostTreatParameters* post, size_t realDstCount);
|
|
|
|
#endif // __aarch64__ && MNN_USE_ARMV82
|
|
|
|
}
|
|
|
|
#endif // MNN_USE_NEON
|
|
|
|
|
|
|
|
#ifndef MNN_USE_NEON
|
|
|
|
static int8_t MNNInt32ToInt8(int data, int bias, float scale, float maxValue, float minValue)
|
2020-11-05 16:41:56 +08:00
|
|
|
{
|
2020-02-26 09:57:17 +08:00
|
|
|
float value = (float)(data + bias) * scale;
|
2020-11-05 16:41:56 +08:00
|
|
|
value = ALIMAX(value, minValue);
|
|
|
|
value = ALIMIN(value, maxValue);
|
2020-02-26 09:57:17 +08:00
|
|
|
return static_cast<int8_t>(roundf(value));
|
2019-04-17 10:49:11 +08:00
|
|
|
}
|
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step,
|
2020-12-10 12:47:38 +08:00
|
|
|
size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) {
|
2020-02-26 09:57:17 +08:00
|
|
|
const auto dst_step_tmp = dst_step / sizeof(int8_t);
|
|
|
|
for (int dz = 0; dz < dst_depth_quad; ++dz) {
|
|
|
|
const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
|
2020-11-05 16:41:56 +08:00
|
|
|
const auto bias_dz = post->bias + dz * GEMM_INT8_UNIT;
|
|
|
|
const auto scale_dz = post->scale + dz * GEMM_INT8_UNIT;
|
2020-02-26 09:57:17 +08:00
|
|
|
auto dst_z = dst + dz * dst_step_tmp;
|
2020-12-10 12:47:38 +08:00
|
|
|
for (int w = 0; w < realCount; ++w) {
|
2020-02-26 09:57:17 +08:00
|
|
|
const auto src_x = src + w * GEMM_INT8_SRC_UNIT;
|
|
|
|
auto dst_x = dst_z + w * GEMM_INT8_UNIT;
|
|
|
|
int32_t dstTemp[4] = {0, 0, 0, 0};
|
2019-04-17 10:49:11 +08:00
|
|
|
|
2020-02-26 09:57:17 +08:00
|
|
|
for (int sz = 0; sz < src_depth_quad; ++sz) {
|
|
|
|
const auto weight_sz = weight_dz + (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sz;
|
|
|
|
const auto src_z = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT;
|
|
|
|
|
|
|
|
for (int j = 0; j < GEMM_INT8_UNIT; ++j) {
|
|
|
|
const auto weight_j = weight_sz + j * GEMM_INT8_SRC_UNIT;
|
|
|
|
for (int i = 0; i < GEMM_INT8_SRC_UNIT; ++i) {
|
|
|
|
dstTemp[j] += (int32_t)src_z[i] * (int32_t)weight_j[i];
|
2019-04-17 10:49:11 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-02-26 09:57:17 +08:00
|
|
|
for (int j = 0; j < 4; ++j) {
|
2021-06-11 17:17:13 +08:00
|
|
|
if (post != nullptr) {
|
|
|
|
dst_x[j] = MNNInt32ToInt8(dstTemp[j], bias_dz[j], scale_dz[j], post->maxValue, post->minValue);
|
|
|
|
} else {
|
2021-06-24 14:06:10 +08:00
|
|
|
((float*)dst_x)[j] = (float)(dstTemp[j] + bias_dz[j]);
|
2021-06-11 17:17:13 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
static void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) {
|
|
|
|
return MNNGemmInt8AddBiasScale_16x4_Unit(dst, src, weight, src_depth_quad, dst_step, dst_depth_quad, post, realCount);
|
|
|
|
}
|
|
|
|
|
|
|
|
static void MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters,
|
|
|
|
size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step,
|
|
|
|
size_t dilateY_step) {
|
|
|
|
auto bias_z = parameters->bias;
|
|
|
|
auto scale_z = parameters->scale;
|
|
|
|
int dx, fx, fy;
|
|
|
|
for (dx = 0; dx < width; ++dx) {
|
|
|
|
auto dst_x = dst + dx * 4;
|
|
|
|
int32_t dstInt32[4] = {0, 0, 0, 0};
|
|
|
|
const auto src_z = src + src_w_step * dx;
|
|
|
|
for (fy = 0; fy < fh; ++fy) {
|
|
|
|
const auto src_y = src_z + fy * dilateY_step;
|
|
|
|
const auto weight_y = weight + fy * fw * 4;
|
|
|
|
for (fx = 0; fx < fw; ++fx) {
|
|
|
|
const auto src_x = src_y + fx * dilateX_step;
|
|
|
|
const auto weight_x = weight_y + 4 * fx;
|
|
|
|
for (int j = 0; j < GEMM_INT8_UNIT; ++j) {
|
|
|
|
dstInt32[j] += (int32_t)src_x[j] * (int32_t)weight_x[j];
|
|
|
|
}
|
2020-02-26 09:57:17 +08:00
|
|
|
}
|
2019-04-17 10:49:11 +08:00
|
|
|
}
|
2021-06-11 17:17:13 +08:00
|
|
|
|
|
|
|
for (int i = 0; i < GEMM_INT8_UNIT; ++i) {
|
|
|
|
dst_x[i] = MNNInt32ToInt8(dstInt32[i], bias_z[i], scale_z[i], parameters->maxValue, parameters->minValue);
|
|
|
|
}
|
2019-04-17 10:49:11 +08:00
|
|
|
}
|
|
|
|
}
|
2021-06-11 17:17:13 +08:00
|
|
|
#endif
|
2021-01-06 16:29:37 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
#ifndef MNN_USE_SSE
|
|
|
|
// Int8 function initialization for builds without MNN_USE_SSE: there is
// nothing to set up here, so the body is intentionally empty.
void MNNInt8FunctionInit() {
}
|
|
|
|
#ifndef MNN_USE_NEON
|
2020-02-26 09:57:17 +08:00
|
|
|
// Quantizes sizeQuad groups of 4 floats to int8.
//
// Each value is multiplied by the per-lane scale scalep[j] (j = index within
// the group of 4), rounded with roundf, offset by zeroPoint and clamped to
// [minValue, maxValue] before being stored.
//
// The rounded value is limited to +/-2^62 before it is converted to an
// integer: converting an out-of-range or NaN float to an integer is undefined
// behavior, and on common targets made huge positive inputs saturate to
// minValue instead of maxValue. NaN is mapped to the lower limit and therefore
// ends up at minValue.
void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue,
                   ssize_t maxValue, ssize_t zeroPoint) {
    const float kLimit = 4611686018427387904.0f; // 2^62, exactly representable as float
    for (size_t i = 0; i < sizeQuad; ++i) {
        for (int j = 0; j < 4; ++j) {
            float rounded = roundf(src[4 * i + j] * scalep[j]);
            if (rounded > kLimit) {
                rounded = kLimit;
            }
            if (!(rounded > -kLimit)) { // written this way so NaN is caught too
                rounded = -kLimit;
            }
            long long v = (long long)rounded + (long long)zeroPoint;
            if (v > maxValue) {
                v = maxValue;
            }
            if (v < minValue) {
                v = minValue;
            }
            dst[4 * i + j] = (int8_t)v;
        }
    }
}
|
2021-06-11 17:17:13 +08:00
|
|
|
|
2021-01-06 16:29:37 +08:00
|
|
|
// Dequantizes `size` groups of 4 int8 values to float:
//   dst = (src - zeroPoint) * scale[lane]
// where lane is the position of the value inside its group of 4.
void MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t size, ssize_t zeroPoint) {
    const int8_t* in = src;
    float* out       = dst;
    for (size_t quad = 0; quad < size; ++quad, in += 4, out += 4) {
        for (int lane = 0; lane < 4; ++lane) {
            out[lane] = (float)(in[lane] - zeroPoint) * scale[lane];
        }
    }
}
|
2021-06-11 17:17:13 +08:00
|
|
|
#endif // #ifndef MNN_USE_NEON
|
|
|
|
#endif // #ifndef MNN_USE_SSE
|
2020-12-10 17:53:24 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
/* CPU without sdot */
|
|
|
|
// Assume GEMM_INT8_UNIT == 4 && GEMM_INT8_SRC_UNIT == 16
|
|
|
|
static void _fastIm2Col(int8_t* colAddr, const int8_t* inputOrigin, int8_t inputZeroPoint,
|
|
|
|
const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
|
|
|
|
size_t realDstCount) {
|
|
|
|
const int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_SRC_UNIT * GEMM_INT8_DST_XUNIT * sizeof(int8_t);
|
|
|
|
::memset(colAddr, inputZeroPoint, col_buffer_size); // the padding process, since per-channel is removed, this is all right
|
2021-01-06 16:29:37 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
const int icDiv8 = im2colParameter->icDiv4 / 2;
|
|
|
|
const int srcZStep = im2colParameter->iw * im2colParameter->ih * GEMM_INT8_UNIT;
|
|
|
|
inputOrigin += xIndexStart * GEMM_INT8_UNIT;
|
|
|
|
for (int i = 0; i < realDstCount; ++i) {
|
|
|
|
auto colAddrI = colAddr + GEMM_INT8_SRC_UNIT * i;
|
|
|
|
auto inputK = inputOrigin + GEMM_INT8_UNIT * i;
|
|
|
|
for (int sz = 0; sz < icDiv8; ++sz) {
|
|
|
|
auto inputZ0 = inputK + srcZStep * (2 * sz + 0);
|
|
|
|
auto inputZ1 = inputK + srcZStep * (2 * sz + 1);
|
|
|
|
const int indexOutside = sz / 2;
|
|
|
|
const int indexInsize = sz % 2;
|
2021-01-06 16:29:37 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
auto dstK0 = colAddrI + (indexOutside * GEMM_INT8_DST_XUNIT * 2 + indexInsize) * (2 * GEMM_INT8_UNIT);
|
|
|
|
auto dstK1 = dstK0 + GEMM_INT8_UNIT;
|
|
|
|
*((int32_t*)dstK0) = *((int32_t*)inputZ0);
|
|
|
|
*((int32_t*)dstK1) = *((int32_t*)inputZ1);
|
2021-01-06 16:29:37 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
static void _im2colCommonZ1(int8_t* colAddr, const int8_t* inputOrigin, int8_t inputZeroPoint,
|
|
|
|
const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
|
|
|
|
size_t realDstCount) {
|
|
|
|
int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
|
|
|
|
::memset(colAddr, inputZeroPoint, col_buffer_size); // the padding process, since per-channel is removed, this is all right
|
2020-12-10 11:28:55 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
auto ih = im2colParameter->ih;
|
|
|
|
auto iw = im2colParameter->iw;
|
|
|
|
auto kh = im2colParameter->kernelY;
|
|
|
|
auto kw = im2colParameter->kernelX;
|
|
|
|
auto dilateX = im2colParameter->dilateX;
|
|
|
|
auto dilateY = im2colParameter->dilateY;
|
|
|
|
auto srcYStep = im2colParameter->srcYStep;
|
|
|
|
constexpr int dstXStepInt32 = GEMM_INT8_SRC_UNIT * GEMM_INT8_DST_XUNIT / sizeof(int32_t);
|
|
|
|
for (int i = 0; i < realDstCount; ++i) {
|
|
|
|
int xIndex = (int)xIndexStart + i;
|
|
|
|
int ox = xIndex % im2colParameter->ow;
|
|
|
|
int oy = xIndex / im2colParameter->ow;
|
|
|
|
|
|
|
|
int sx = ox * im2colParameter->strideX - im2colParameter->padX;
|
|
|
|
int sy = oy * im2colParameter->strideY - im2colParameter->padY;
|
|
|
|
|
|
|
|
int sfy = ALIMAX(0, (UP_DIV(-sy, im2colParameter->dilateY)));
|
|
|
|
int efy = ALIMIN(kh, UP_DIV(ih - sy, im2colParameter->dilateY));
|
|
|
|
int sfx = ALIMAX(0, (UP_DIV(-sx, im2colParameter->dilateX)));
|
|
|
|
int efx = ALIMIN(kw, UP_DIV(iw - sx, im2colParameter->dilateX));
|
|
|
|
int fyC = efy - sfy;
|
|
|
|
int fxC = efx - sfx;
|
|
|
|
|
|
|
|
auto colAddrI = colAddr + GEMM_INT8_SRC_UNIT * i;
|
|
|
|
|
|
|
|
auto inputOffset = inputOrigin + (sy + sfy * dilateY) * srcYStep + (sx + sfx * dilateX) * GEMM_INT8_UNIT;
|
|
|
|
auto indexOffset = sfy * kw + sfx;
|
|
|
|
for (int fy = 0; fy < fyC; ++fy) {
|
|
|
|
for (int fx = 0; fx < fxC; ++fx) {
|
|
|
|
auto inputK = inputOffset + fy * dilateY * srcYStep + fx * dilateX * GEMM_INT8_UNIT;
|
|
|
|
auto indexStart = indexOffset + fy * kw + fx;
|
|
|
|
auto indexInside = indexStart % 4;
|
|
|
|
auto indexOutside = indexStart / 4;
|
|
|
|
auto dstK0 = (int32_t*)colAddrI + indexOutside * dstXStepInt32 + indexInside;
|
|
|
|
dstK0[0] = *((int32_t*)inputK);
|
2019-04-17 10:49:11 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2020-03-22 20:16:29 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
static void _im2colCommon(int8_t* colAddr, const int8_t* inputOrigin, int8_t inputZeroPoint,
|
|
|
|
const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
|
|
|
|
size_t realDstCount) {
|
|
|
|
const int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
|
|
|
|
::memset(colAddr, inputZeroPoint, col_buffer_size); // the padding process, since per-channel is removed, this is all right
|
2020-03-22 20:16:29 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
auto ih = im2colParameter->ih;
|
|
|
|
auto iw = im2colParameter->iw;
|
|
|
|
auto kh = im2colParameter->kernelY;
|
|
|
|
auto kw = im2colParameter->kernelX;
|
|
|
|
auto dilateX = im2colParameter->dilateX;
|
|
|
|
auto dilateY = im2colParameter->dilateY;
|
|
|
|
auto icDiv4 = im2colParameter->icDiv4;
|
|
|
|
auto srcZStep = im2colParameter->srcZStep;
|
|
|
|
auto srcYStep = im2colParameter->srcYStep;
|
|
|
|
constexpr int dstXStepInt32 = GEMM_INT8_SRC_UNIT * GEMM_INT8_DST_XUNIT / sizeof(int32_t);
|
|
|
|
for (int i = 0; i < realDstCount; ++i) {
|
|
|
|
int xIndex = (int)xIndexStart + i;
|
|
|
|
int ox = xIndex % im2colParameter->ow;
|
|
|
|
int oy = xIndex / im2colParameter->ow;
|
2020-03-22 20:16:29 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
int sx = ox * im2colParameter->strideX - im2colParameter->padX;
|
|
|
|
int sy = oy * im2colParameter->strideY - im2colParameter->padY;
|
2020-03-22 20:16:29 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
int sfy = ALIMAX(0, (UP_DIV(-sy, im2colParameter->dilateY)));
|
|
|
|
int efy = ALIMIN(kh, UP_DIV(ih - sy, im2colParameter->dilateY));
|
|
|
|
int sfx = ALIMAX(0, (UP_DIV(-sx, im2colParameter->dilateX)));
|
|
|
|
int efx = ALIMIN(kw, UP_DIV(iw - sx, im2colParameter->dilateX));
|
|
|
|
int fyC = efy - sfy;
|
|
|
|
int fxC = efx - sfx;
|
2020-03-22 20:16:29 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
auto colAddrI = colAddr + GEMM_INT8_SRC_UNIT * i;
|
|
|
|
|
|
|
|
auto inputOffset = inputOrigin + (sy + sfy * dilateY) * srcYStep + (sx + sfx * dilateX) * GEMM_INT8_UNIT;
|
|
|
|
auto indexOffset = (sfy * kw + sfx) * icDiv4;
|
|
|
|
for (int fy = 0; fy < fyC; ++fy) {
|
|
|
|
for (int fx = 0; fx < fxC; ++fx) {
|
|
|
|
auto inputK = inputOffset + fy * dilateY * srcYStep + fx * dilateX * GEMM_INT8_UNIT;
|
|
|
|
auto indexStart = indexOffset + (fy * kw + fx) * icDiv4;
|
|
|
|
for (int sz = 0; sz < icDiv4; ++sz) {
|
|
|
|
const int yIndex = indexStart + sz;
|
|
|
|
const int ySubOutside = yIndex / GEMM_INT8_UNIT;
|
|
|
|
const int ySubInside = yIndex % GEMM_INT8_UNIT;
|
|
|
|
auto dstK0 = (int32_t*)colAddrI + ySubOutside * dstXStepInt32 + ySubInside;
|
|
|
|
dstK0[0] = *((int32_t*)inputK);
|
|
|
|
inputK += srcZStep;
|
2020-03-22 20:16:29 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
// Picks the im2col routine for the non-sdot kernels.
//
// _fastIm2Col is returned for a 1x1 kernel with stride 1, no padding, an even
// icDiv4 and input whose row/channel steps match a dense iw x ih layout.
// Otherwise _im2colCommonZ1 is returned when inputChannel <= 4, and
// _im2colCommon in every other case.
static MNN::CoreInt8Functions::Im2ColFunc chooseIm2Col(const MNN::ConvolutionCommon::Im2ColParameter* im2colParam, size_t inputChannel) {
    const int ih = im2colParam->ih;
    const int iw = im2colParam->iw;

    const bool pointwiseKernel = im2colParam->kernelX == 1 && im2colParam->kernelY == 1;
    const bool evenChannelQuad = im2colParam->icDiv4 % 2 == 0;
    const bool unitStride      = im2colParam->strideX == 1 && im2colParam->strideY == 1;
    const bool noPadding       = im2colParam->padX == 0 && im2colParam->padY == 0;
    const bool denseInput      = im2colParam->srcYStep == iw * GEMM_INT8_UNIT &&
                                 im2colParam->srcZStep == ih * iw * GEMM_INT8_UNIT;

    if (pointwiseKernel && evenChannelQuad && unitStride && noPadding && denseInput) {
        return _fastIm2Col;
    }
    if (inputChannel <= 4) {
        return _im2colCommonZ1;
    }
    return _im2colCommon;
}
|
2020-11-05 16:41:56 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
// Reports the packing constants of the non-sdot GEMM kernels, i.e. the values
// GEMM_INT8_UNIT, GEMM_INT8_SRC_UNIT and GEMM_INT8_DST_XUNIT have at this
// point of the file. All three output pointers are written.
static void MNNGetGemmUnit(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) {
    *DST_XUNIT = GEMM_INT8_DST_XUNIT;
    *SRC_UNIT  = GEMM_INT8_SRC_UNIT;
    *UNIT      = GEMM_INT8_UNIT;
}
|
|
|
|
#undef GEMM_INT8_UNIT
|
|
|
|
#undef GEMM_INT8_SRC_UNIT
|
|
|
|
#undef GEMM_INT8_DST_XUNIT
|
|
|
|
/* End */
|
2020-11-05 16:41:56 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
/* CPU with sdot */
|
|
|
|
#define GEMM_INT8_UNIT 4
|
|
|
|
#define GEMM_INT8_SRC_UNIT 4
|
|
|
|
|
|
|
|
#ifdef __aarch64__
|
|
|
|
#define GEMM_INT8_DST_XUNIT 12
|
2020-11-05 16:41:56 +08:00
|
|
|
#else
|
2021-06-11 17:17:13 +08:00
|
|
|
#define GEMM_INT8_DST_XUNIT 8
|
2020-11-05 16:41:56 +08:00
|
|
|
#endif
|
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
static void _im2colCommonSdot(int8_t* colAddr, const int8_t* src, int8_t inputZeroPoint,
|
|
|
|
const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
|
|
|
|
size_t realDstCount) {
|
|
|
|
const int colBufferSize = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
|
|
|
|
memset(colAddr, inputZeroPoint, colBufferSize);
|
|
|
|
auto ih = im2colParameter->ih;
|
|
|
|
auto iw = im2colParameter->iw;
|
|
|
|
// auto oh = im2colParameter->oh;
|
|
|
|
auto ow = im2colParameter->ow;
|
|
|
|
auto kh = im2colParameter->kernelY;
|
|
|
|
auto kw = im2colParameter->kernelX;
|
|
|
|
auto dilateX = im2colParameter->dilateX;
|
|
|
|
auto dilateY = im2colParameter->dilateY;
|
|
|
|
auto icDiv4 = im2colParameter->icDiv4;
|
|
|
|
auto srcChannleStride = im2colParameter->srcZStep;
|
|
|
|
auto srcYStep = im2colParameter->srcYStep;
|
|
|
|
constexpr int dstXStepInt32 = GEMM_INT8_UNIT * GEMM_INT8_DST_XUNIT / sizeof(int32_t);
|
|
|
|
|
|
|
|
for (int i = 0; i < realDstCount; ++i) {
|
|
|
|
int xIndex = (int)xIndexStart + i;
|
|
|
|
int ox = xIndex % ow;
|
|
|
|
int oy = xIndex / ow;
|
|
|
|
int sx = ox * im2colParameter->strideX - im2colParameter->padX;
|
|
|
|
int sy = oy * im2colParameter->strideY - im2colParameter->padY;
|
|
|
|
int sfy = ALIMAX(0, (UP_DIV(-sy, im2colParameter->dilateY)));
|
|
|
|
int efy = ALIMIN(kh, UP_DIV(ih - sy, im2colParameter->dilateY));
|
|
|
|
int sfx = ALIMAX(0, (UP_DIV(-sx, im2colParameter->dilateX)));
|
|
|
|
int efx = ALIMIN(kw, UP_DIV(iw - sx, im2colParameter->dilateX));
|
|
|
|
int fyC = efy - sfy;
|
|
|
|
int fxC = efx - sfx;
|
|
|
|
|
|
|
|
auto colAddrI = colAddr + GEMM_INT8_UNIT * i;
|
|
|
|
auto inputOffset = src + (sy + sfy * dilateY) * srcYStep + (sx + sfx * dilateX) * GEMM_INT8_UNIT;
|
|
|
|
auto indexOffset = (sfy * kw + sfx) * icDiv4;
|
|
|
|
|
|
|
|
for (int fy = 0; fy < fyC; ++fy) {
|
|
|
|
for (int fx = 0; fx < fxC; ++fx) {
|
|
|
|
auto inputK = inputOffset + fy * dilateY * srcYStep + fx * dilateX * GEMM_INT8_UNIT;
|
|
|
|
auto indexStart = (indexOffset + (fy * kw + fx) * icDiv4) * dstXStepInt32;
|
|
|
|
for (int sz = 0; sz < icDiv4; ++sz) {
|
|
|
|
auto dstK0 = (int32_t*)colAddrI + indexStart + sz * dstXStepInt32;
|
|
|
|
dstK0[0] = *((int32_t*)inputK);
|
|
|
|
inputK += srcChannleStride;
|
|
|
|
}
|
2020-11-05 16:41:56 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
static void _fastIm2ColSdot(int8_t* colAddr, const int8_t* inputOrigin, int8_t inputZeroPoint,
|
|
|
|
const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
|
|
|
|
size_t realDstCount) {
|
|
|
|
const int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
|
|
|
|
::memset(colAddr, inputZeroPoint, col_buffer_size);
|
|
|
|
const int icDiv4 = im2colParameter->icDiv4;
|
|
|
|
const int srcZStep = im2colParameter->iw * im2colParameter->ih * GEMM_INT8_UNIT;
|
|
|
|
inputOrigin += xIndexStart * GEMM_INT8_UNIT;
|
|
|
|
for (int i = 0; i < realDstCount; ++i) {
|
|
|
|
auto colAddrI = colAddr + GEMM_INT8_UNIT * i;
|
|
|
|
auto inputK = inputOrigin + GEMM_INT8_UNIT * i;
|
|
|
|
for (int sz = 0; sz < icDiv4; ++sz) {
|
|
|
|
auto inputZ0 = inputK + srcZStep * sz;
|
|
|
|
auto dstK0 = colAddrI + sz * GEMM_INT8_UNIT * GEMM_INT8_DST_XUNIT;
|
|
|
|
*((int32_t*)dstK0) = *((int32_t*)inputZ0);
|
2020-11-05 16:41:56 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
// Picks the im2col routine for the sdot kernels.
//
// _fastIm2ColSdot is returned for a 1x1 kernel with stride 1, no padding, an
// even icDiv4 and input whose row/channel steps match a dense iw x ih layout;
// _im2colCommonSdot is returned otherwise. inputChannel is not consulted.
static MNN::CoreInt8Functions::Im2ColFunc chooseIm2ColSdot(const MNN::ConvolutionCommon::Im2ColParameter* im2colParam, size_t inputChannel) {
    const int ih = im2colParam->ih;
    const int iw = im2colParam->iw;

    const bool pointwiseKernel = im2colParam->kernelX == 1 && im2colParam->kernelY == 1;
    const bool evenChannelQuad = im2colParam->icDiv4 % 2 == 0;
    const bool unitStride      = im2colParam->strideX == 1 && im2colParam->strideY == 1;
    const bool noPadding       = im2colParam->padX == 0 && im2colParam->padY == 0;
    const bool denseInput      = im2colParam->srcYStep == iw * GEMM_INT8_UNIT &&
                                 im2colParam->srcZStep == ih * iw * GEMM_INT8_UNIT;

    if (pointwiseKernel && evenChannelQuad && unitStride && noPadding && denseInput) {
        return _fastIm2ColSdot;
    }
    return _im2colCommonSdot;
}
|
2021-01-06 16:29:37 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
// Reports the packing constants of the sdot GEMM kernels, i.e. the values
// GEMM_INT8_UNIT, GEMM_INT8_SRC_UNIT and GEMM_INT8_DST_XUNIT have in the
// "CPU with sdot" section of this file. All three output pointers are written.
static void MNNGetGemmUnitSdot(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) {
    *DST_XUNIT = GEMM_INT8_DST_XUNIT;
    *SRC_UNIT  = GEMM_INT8_SRC_UNIT;
    *UNIT      = GEMM_INT8_UNIT;
}
|
2021-01-06 16:29:37 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
/* End */
|
|
|
|
#undef GEMM_INT8_UNIT
|
|
|
|
#undef GEMM_INT8_SRC_UNIT
|
|
|
|
#undef GEMM_INT8_DST_XUNIT
|
|
|
|
|
|
|
|
namespace MNN {
|
|
|
|
|
|
|
|
static CoreInt8Functions* gCoreFunc = nullptr;
|
|
|
|
|
|
|
|
void MNNCoreInt8FunctionInit() {
|
|
|
|
/* CoreInt8Functions without sdot */
|
|
|
|
gCoreFunc = new CoreInt8Functions;
|
|
|
|
// MatMul
|
|
|
|
gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_16x4_Unit;
|
|
|
|
gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_16x4_Unit_FAST;
|
|
|
|
gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnit;
|
|
|
|
// Im2Col
|
|
|
|
gCoreFunc->chooseIm2Col = chooseIm2Col;
|
|
|
|
// conv depthwise
|
|
|
|
gCoreFunc->ConvDepthwiseLineInt8 = MNNLineDepthWiseInt8AddBiasScaleUnit;
|
|
|
|
|
|
|
|
#if defined(__aarch64__) && defined(MNN_USE_ARMV82)
|
|
|
|
auto core = MNNGetCoreFunctions();
|
|
|
|
if (core->supportSDot) {
|
|
|
|
// MatMul
|
|
|
|
gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_ARMV82_Unit;
|
|
|
|
gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_ARMV82_Unit;
|
|
|
|
gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnitSdot;
|
|
|
|
// Im2Col
|
|
|
|
gCoreFunc->chooseIm2Col = chooseIm2ColSdot;
|
2021-01-06 16:29:37 +08:00
|
|
|
}
|
|
|
|
#endif
|
2021-06-11 17:17:13 +08:00
|
|
|
MNNInt8FunctionInit();
|
|
|
|
}
|
|
|
|
CoreInt8Functions* MNNGetInt8CoreFunctions() {
|
|
|
|
return gCoreFunc;
|
|
|
|
}
|
|
|
|
};
|