mirror of https://github.com/alibaba/MNN.git
//
//  Int8FunctionsOpt.cpp
//  MNN
//
//  Created by MNN on 2018/08/15.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <math.h>
#include <cstring> // for memset
#include "Int8FunctionsOpt.h"
#include "core/Macro.h"
#include "common/CommonCompute.hpp"
#include "CommonOptFunction.h"

#ifdef MNN_USE_NEON
#include <arm_neon.h>

extern "C" {
void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad,
                                       const QuanPostTreatParameters* post, size_t realCount);
void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad,
                                            const QuanPostTreatParameters* post, size_t realCount);
void MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width,
                                          size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step);
void MNNMaxPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx);

void MNNAvgPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx, ssize_t paddingx, ssize_t factor);

#if defined(__aarch64__) // aarch32 sdot workaround
void MNNGemmInt8AddBiasScale_ARMV82_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad,
                                         const QuanPostTreatParameters* post, size_t realDstCount);
void MNNGemmInt8AddBiasScale_ARMV86_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad,
                                         const QuanPostTreatParameters* post, size_t realDstCount);
#endif // __aarch64__
}
#endif // MNN_USE_NEON

/*
layout should be optimized for int8
source: source matrix is h x l
transpose: if false, export compressed matrix as h x l, otherwise export as l x h.
*/
void MNNPackForSparseQuantMatMul_B(int8_t* dest, unsigned int* NNZMap, int* dataOffsetMap, int sparseBlockOC, const int8_t* source, size_t h, size_t kernelCount, size_t icCount, const int eP) {
    // 1. in quant convolution, source B layout is OC x (IC * KH * KW),
    //    the dest layout of weight is BCSC (block compressed sparse column) format, which is OC(!=0) x (KH*KW*IC!=0); since the transpose cancels out, just do BCSR
    // 2. IC would be switched into the last dim.

    // BCSC
    int columOffset = 0;
    int i = 0;
    auto subSource = source;
    size_t l = kernelCount * icCount;
    for (; i + sparseBlockOC <= h; i += sparseBlockOC) {
        *NNZMap = 0;
        for (int ik = 0; ik < kernelCount; ik += 1) {
            auto kernelSource = subSource + ik;
            for (int ic = 0; ic < icCount; ic += 1) {
                if (!MNN::CommonCompute::checkAllZeros(kernelSource, l, sparseBlockOC, 1)) {
                    for (int ioc = 0; ioc < sparseBlockOC; ioc++) {
                        *dest = *(kernelSource + ioc * l);
                        dest++;
                    }
                    *NNZMap = *NNZMap + 1;
                    *dataOffsetMap = columOffset;
                    dataOffsetMap++;
                    columOffset = 0;
                }
                columOffset += eP;
                kernelSource += kernelCount;
            }
        }
        NNZMap++;
        columOffset -= l * eP;
        subSource += sparseBlockOC * l;
    }

    for (; i < h; i++) {
        *NNZMap = 0;
        for (int ik = 0; ik < kernelCount; ik += 1) {
            auto kernelSource = subSource + ik;
            for (int ic = 0; ic < icCount; ic += 1) {
                if (*kernelSource != 0) {
                    *dest = *kernelSource;
                    dest++;
                    *NNZMap = *NNZMap + 1;
                    *dataOffsetMap = columOffset;
                    dataOffsetMap++;
                    columOffset = 0;
                }
                columOffset += eP;
                kernelSource += kernelCount;
            }
        }
        NNZMap++;
        columOffset -= l * eP;
        subSource += l;
    }

    *dataOffsetMap = columOffset; // trailing offset entry

    return;
}
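
// Editor's illustrative sketch (not in the original source): for h = 1, kernelCount = 1,
// icCount = 4, eP = 16 and the single row [5, 0, 0, -3], the loops above emit
// dest = {5, -3}, NNZMap = {2} and dataOffsetMap = {0, 3 * 16, -3 * 16}: each stored
// offset is eP times the column gap to the next non-zero, and the trailing negative
// entry rewinds A so the matmul kernels can step between panels with a plain "a += diff".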

void MNNGetSparseQuantMatMulPackMode(int* eP, int* lP, int* hP) {
#if defined(__arm__) && !defined(__aarch64__)
    *eP = 8;
#else
    *eP = 16;
#endif
    *lP = 1;
    *hP = 4;
    // hP corresponds to the sparse block size along the right matrix column dimension; in random sparse, it is 1.
    return;
}
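
// Editor's illustrative sketch (not in the original source): callers query the pack
// mode before sizing the sparse buffers, e.g.
//
//     int eP, lP, hP;
//     MNNGetSparseQuantMatMulPackMode(&eP, &lP, &hP);
//
// With hP = 4, MNNPackForSparseQuantMatMul_B above emits one NNZMap counter per 4-row
// block plus one per leftover row (roughly h / hP + h % hP counters), and one
// dataOffsetMap entry per stored non-zero plus the trailing rewind entry.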

static void MNNSparseQuantIm2col(int8_t* colAddr, const int8_t* inputOrigin, int8_t inputZeroPoint,
                                 const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, const size_t* sparseQuantParam, size_t xIndexStart) {
    auto ih = im2colParameter->ih;
    auto iw = im2colParameter->iw;
    auto kh = im2colParameter->kernelY;
    auto kw = im2colParameter->kernelX;
    auto dilateX = im2colParameter->dilateX;
    auto dilateY = im2colParameter->dilateY;
    auto icDiv4 = im2colParameter->icDiv4;
    auto srcZStep = im2colParameter->srcZStep;
    auto srcYStep = im2colParameter->srcYStep;
    auto destICStride = im2colParameter->destICStride;
    auto packCUnit = im2colParameter->packCUnit;

    size_t eSize = sparseQuantParam[0];
    size_t eP = sparseQuantParam[1];
    size_t l = sparseQuantParam[3];
    size_t ePx4 = eP << 2;
    const int col_buffer_size = l * eP * sizeof(int8_t);
    ::memset(colAddr, inputZeroPoint, col_buffer_size); // handles padding; since per-channel quantization was removed, filling with the input zero point is correct

    for (int i = 0; i < eSize; ++i) {
        int xIndex = (int)xIndexStart + i;
        int ox = xIndex % im2colParameter->ow;
        int oy = xIndex / im2colParameter->ow;

        int sx = ox * im2colParameter->strideX - im2colParameter->padX;
        int sy = oy * im2colParameter->strideY - im2colParameter->padY;

        int sfy = ALIMAX(0, (UP_DIV(-sy, im2colParameter->dilateY)));
        int efy = ALIMIN(kh, UP_DIV(ih - sy, im2colParameter->dilateY));
        int sfx = ALIMAX(0, (UP_DIV(-sx, im2colParameter->dilateX)));
        int efx = ALIMIN(kw, UP_DIV(iw - sx, im2colParameter->dilateX));
        int fyC = efy - sfy;
        int fxC = efx - sfx;

        auto inputOffset = inputOrigin + (sy + sfy * dilateY) * srcYStep + (sx + sfx * dilateX) * packCUnit; // offset in (c/4, ih, iw, 4)
        auto destBase = colAddr + (sfy * kw + sfx) * destICStride + i;
        for (int fy = 0; fy < fyC; ++fy) {
            for (int fx = 0; fx < fxC; ++fx) {
                auto inputK = inputOffset + fy * dilateY * srcYStep + fx * dilateX * packCUnit; // origin data matrix offset inside the kernel window
                auto destWrite = destBase + (fy * kw + fx) * destICStride;
                int8_t* destWrite4[4] = {
                    destWrite,
                    destWrite + eP,
                    destWrite + 2 * eP,
                    destWrite + 3 * eP
                };
                for (int sz = 0; sz < icDiv4; ++sz) {
                    // scalar reference:
                    // for (int ic4 = 0; ic4 < packCUnit; ic4++) {
                    //     *destWrite = inputK[ic4];
                    //     destWrite += eP;
                    // }
                    int8_t c4[4];
                    memcpy(c4, inputK, sizeof(int32_t));
                    *(destWrite4[0]) = c4[0];
                    *(destWrite4[1]) = c4[1];
                    *(destWrite4[2]) = c4[2];
                    *(destWrite4[3]) = c4[3];

                    destWrite4[0] += ePx4;
                    destWrite4[1] += ePx4;
                    destWrite4[2] += ePx4;
                    destWrite4[3] += ePx4;
                    inputK += srcZStep;
                }
            }
        }
    }

}
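
// Editor's note: the inner loop above de-interleaves one C4 pixel into four eP-wide
// rows of the column buffer. With eP = 16, channel c of channel-block sz for output
// pixel i lands at its kernel-position base + (sz * 4 + c) * 16 + i, which is exactly
// the l x eP panel layout the Epx1/Epx4 kernels below walk via dataOffsetMap strides.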

#ifndef MNN_USE_NEON

void MNNPackedSparseQuantMatMulEpx1(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam, const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap) {

    size_t eSize = sparseQuantParam[0];
    size_t eP = sparseQuantParam[1];
    size_t aStride = sparseQuantParam[2];
    size_t l = sparseQuantParam[3];
    size_t h = sparseQuantParam[4];
    size_t cStride = sparseQuantParam[5];

    const int32_t* bias = post->bias;
    const float* scales = post->scale;
    const int32_t maxValue = post->maxValue;
    const int32_t minValue = post->minValue;

    const int sparseBlockOC = 4;
    const int8_t* a = A;
    size_t ie = 0;
    for (ie = 0; ie < eSize && eP <= eSize; ie += eP) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        a += diff;
        const int8_t* w = B;
        int8_t* blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;

        for (size_t ih = 0; ih < h; ih++) {
            auto ihPack = ih >> 2;
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihPack * cStride + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;
            int32_t acc1 = initValue;
            int32_t acc2 = initValue;
            int32_t acc3 = initValue;
            int32_t acc4 = initValue;
            int32_t acc5 = initValue;
            int32_t acc6 = initValue;
            int32_t acc7 = initValue;
            int32_t acc8 = initValue;
            int32_t acc9 = initValue;
            int32_t acc10 = initValue;
            int32_t acc11 = initValue;
            int32_t acc12 = initValue;
            int32_t acc13 = initValue;
            int32_t acc14 = initValue;
            int32_t acc15 = initValue;
            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {

                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t a4 = a[4];
                const int8_t a5 = a[5];
                const int8_t a6 = a[6];
                const int8_t a7 = a[7];
                const int8_t a8 = a[8];
                const int8_t a9 = a[9];
                const int8_t a10 = a[10];
                const int8_t a11 = a[11];
                const int8_t a12 = a[12];
                const int8_t a13 = a[13];
                const int8_t a14 = a[14];
                const int8_t a15 = a[15];

                const int8_t oneW = *w++;

                // MNN_PRINT("16-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-15]:", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {16});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += (int32_t)a0 * (int32_t)oneW;
                acc1 += (int32_t)a1 * (int32_t)oneW;
                acc2 += (int32_t)a2 * (int32_t)oneW;
                acc3 += (int32_t)a3 * (int32_t)oneW;
                acc4 += (int32_t)a4 * (int32_t)oneW;
                acc5 += (int32_t)a5 * (int32_t)oneW;
                acc6 += (int32_t)a6 * (int32_t)oneW;
                acc7 += (int32_t)a7 * (int32_t)oneW;
                acc8 += (int32_t)a8 * (int32_t)oneW;
                acc9 += (int32_t)a9 * (int32_t)oneW;
                acc10 += (int32_t)a10 * (int32_t)oneW;
                acc11 += (int32_t)a11 * (int32_t)oneW;
                acc12 += (int32_t)a12 * (int32_t)oneW;
                acc13 += (int32_t)a13 * (int32_t)oneW;
                acc14 += (int32_t)a14 * (int32_t)oneW;
                acc15 += (int32_t)a15 * (int32_t)oneW;
            }

            int8_t result0; // in assembly code, consider reusing acc0's low 8 bits
            int8_t result1;
            int8_t result2;
            int8_t result3;
            int8_t result4;
            int8_t result5;
            int8_t result6;
            int8_t result7;
            int8_t result8;
            int8_t result9;
            int8_t result10;
            int8_t result11;
            int8_t result12;
            int8_t result13;
            int8_t result14;
            int8_t result15;

            if (scales) {
                result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
                result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue))));
                result2 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue))));
                result3 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue))));
                result4 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc4)), float(minValue))));
                result5 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc5)), float(minValue))));
                result6 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc6)), float(minValue))));
                result7 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc7)), float(minValue))));
                result8 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc8)), float(minValue))));
                result9 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc9)), float(minValue))));
                result10 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc10)), float(minValue))));
                result11 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc11)), float(minValue))));
                result12 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc12)), float(minValue))));
                result13 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc13)), float(minValue))));
                result14 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc14)), float(minValue))));
                result15 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc15)), float(minValue))));
            } else {
                result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
                result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue));
                result2 = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue));
                result3 = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue));
                result4 = static_cast<int8_t>(std::max(std::min(maxValue, acc4), minValue));
                result5 = static_cast<int8_t>(std::max(std::min(maxValue, acc5), minValue));
                result6 = static_cast<int8_t>(std::max(std::min(maxValue, acc6), minValue));
                result7 = static_cast<int8_t>(std::max(std::min(maxValue, acc7), minValue));
                result8 = static_cast<int8_t>(std::max(std::min(maxValue, acc8), minValue));
                result9 = static_cast<int8_t>(std::max(std::min(maxValue, acc9), minValue));
                result10 = static_cast<int8_t>(std::max(std::min(maxValue, acc10), minValue));
                result11 = static_cast<int8_t>(std::max(std::min(maxValue, acc11), minValue));
                result12 = static_cast<int8_t>(std::max(std::min(maxValue, acc12), minValue));
                result13 = static_cast<int8_t>(std::max(std::min(maxValue, acc13), minValue));
                result14 = static_cast<int8_t>(std::max(std::min(maxValue, acc14), minValue));
                result15 = static_cast<int8_t>(std::max(std::min(maxValue, acc15), minValue));
            }

            // how to store faster: st4 / transpose /
            c[0] = result0;
            c[4] = result1;
            c[4 * 2] = result2;
            c[4 * 3] = result3;
            c[4 * 4] = result4;
            c[4 * 5] = result5;
            c[4 * 6] = result6;
            c[4 * 7] = result7;
            c[4 * 8] = result8;
            c[4 * 9] = result9;
            c[4 * 10] = result10;
            c[4 * 11] = result11;
            c[4 * 12] = result12;
            c[4 * 13] = result13;
            c[4 * 14] = result14;
            c[4 * 15] = result15;
        }
        a += aStride;
    }
    if (eSize & 0x08) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        // a = blockA + diff;
        a += diff;
        const int8_t* w = B;
        int8_t* blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;
        for (size_t ih = 0; ih < h; ih++) {
            auto ihPack = ih >> 2;
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihPack * cStride + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;
            int32_t acc1 = initValue;
            int32_t acc2 = initValue;
            int32_t acc3 = initValue;
            int32_t acc4 = initValue;
            int32_t acc5 = initValue;
            int32_t acc6 = initValue;
            int32_t acc7 = initValue;

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {
                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t a4 = a[4];
                const int8_t a5 = a[5];
                const int8_t a6 = a[6];
                const int8_t a7 = a[7];
                const int8_t oneW = *w++;
                // MNN_PRINT("8-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-7]:\n", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {8});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += int32_t(a0) * int32_t(oneW);
                acc1 += int32_t(a1) * int32_t(oneW);
                acc2 += int32_t(a2) * int32_t(oneW);
                acc3 += int32_t(a3) * int32_t(oneW);
                acc4 += int32_t(a4) * int32_t(oneW);
                acc5 += int32_t(a5) * int32_t(oneW);
                acc6 += int32_t(a6) * int32_t(oneW);
                acc7 += int32_t(a7) * int32_t(oneW);
            }

            int8_t result0;
            int8_t result1;
            int8_t result2;
            int8_t result3;
            int8_t result4;
            int8_t result5;
            int8_t result6;
            int8_t result7;
            if (scales) {
                result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
                result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue))));
                result2 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue))));
                result3 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue))));
                result4 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc4)), float(minValue))));
                result5 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc5)), float(minValue))));
                result6 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc6)), float(minValue))));
                result7 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc7)), float(minValue))));
            } else {
                result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
                result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue));
                result2 = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue));
                result3 = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue));
                result4 = static_cast<int8_t>(std::max(std::min(maxValue, acc4), minValue));
                result5 = static_cast<int8_t>(std::max(std::min(maxValue, acc5), minValue));
                result6 = static_cast<int8_t>(std::max(std::min(maxValue, acc6), minValue));
                result7 = static_cast<int8_t>(std::max(std::min(maxValue, acc7), minValue));
            }

            // how to store faster: st4 / transpose /
            c[0] = result0;
            c[4] = result1;
            c[4 * 2] = result2;
            c[4 * 3] = result3;
            c[4 * 4] = result4;
            c[4 * 5] = result5;
            c[4 * 6] = result6;
            c[4 * 7] = result7;
        }
        ie += 8;
        a += 8;
    }
    if (eSize & 0x04) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        // a = blockA + diff;
        a += diff;
        const int8_t* w = B;
        int8_t* blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;

        for (size_t ih = 0; ih < h; ih++) {
            auto ihPack = ih >> 2;
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihPack * cStride + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;
            int32_t acc1 = initValue;
            int32_t acc2 = initValue;
            int32_t acc3 = initValue;

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {
                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t oneW = *w++;
                // MNN_PRINT("4-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-3]:\n", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {4});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += int32_t(a0) * int32_t(oneW);
                acc1 += int32_t(a1) * int32_t(oneW);
                acc2 += int32_t(a2) * int32_t(oneW);
                acc3 += int32_t(a3) * int32_t(oneW);
            }

            int8_t result0;
            int8_t result1;
            int8_t result2;
            int8_t result3;
            if (scales) {
                result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
                result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue))));
                result2 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue))));
                result3 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue))));
            } else {
                result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
                result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue));
                result2 = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue));
                result3 = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue));
            }

            // how to store faster: st4 / transpose /
            c[0] = result0;
            c[4] = result1;
            c[4 * 2] = result2;
            c[4 * 3] = result3;
        }
        ie += 4;
        a += 4;
    }
    if (eSize & 0x02) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        // a = blockA + diff;
        a += diff;
        const int8_t* w = B;
        int8_t* blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;
        for (size_t ih = 0; ih < h; ih++) {
            auto ihPack = ih >> 2;
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihPack * cStride + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;
            int32_t acc1 = initValue;

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {
                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t oneW = *w++;
                // MNN_PRINT("2-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-1]:\n", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {2});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += int32_t(a0) * int32_t(oneW);
                acc1 += int32_t(a1) * int32_t(oneW);
            }

            int8_t result0;
            int8_t result1;
            if (scales) {
                result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
                result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue))));
            } else {
                result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
                result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue));
            }

            // how to store faster: st4 / transpose /
            c[0] = result0;
            c[4] = result1;
        }
        ie += 2;
        a += 2;
    }
    if (eSize & 0x01) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        // const float* a = blockA + diff;
        a += diff;
        const int8_t* w = B;
        int8_t* blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;
        for (size_t ih = 0; ih < h; ih++) {
            auto ihPack = ih >> 2;
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihPack * cStride + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {
                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t oneW = *w++;

                // MNN_PRINT("1-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0]:\n", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {1});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += int32_t(a0) * int32_t(oneW);
            }
            int8_t result0;
            if (scales) {
                result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
            } else {
                result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
            }
            // how to store faster: st4 / transpose /
            c[0] = result0;
        }
        ie += 1;
        // a += 1;
    }

}
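
// Editor's note on the post treatment used throughout this file: per output channel oc,
//     out = roundf(clamp(scale[oc] * acc, minValue, maxValue))  when a scale is given,
//     out = clamp(acc, minValue, maxValue)                      otherwise.
// Worked example with made-up numbers: acc = 1234, scale = 0.05f and range [-128, 127]
// gives 1234 * 0.05 = 61.7, which survives the clamp and rounds to 62.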

void MNNPackedSparseQuantMatMulEpx4(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam, const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap) {

    size_t eSize = sparseQuantParam[0];
    size_t eP = sparseQuantParam[1];
    size_t aStride = sparseQuantParam[2];
    size_t l = sparseQuantParam[3];
    size_t h = sparseQuantParam[4];
    size_t cStride = sparseQuantParam[5];

    const int32_t* bias = post->bias;
    const float* scales = post->scale;
    const int32_t maxValue = post->maxValue;
    const int32_t minValue = post->minValue;

    const int sparseBlockOC = 4;
    const int8_t* a = A;
    size_t ie = 0;
    for (ie = 0; ie < eSize && eP <= eSize; ie += eP) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        a += diff;
        const int8_t* w = B;
        int8_t* blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < (h & (~0x03)); ih += sparseBlockOC) {
            auto ihPack = ih >> 2;
            auto c = blockC + ihPack * cStride;

            int32_t initValue[4] = {0, 0, 0, 0};
            if (nullptr != bias) {
                memcpy(initValue, bias + ih, 4 * sizeof(int32_t));
            }
            int32_t acc0[4];
            int32_t acc1[4];
            int32_t acc2[4];
            int32_t acc3[4];
            int32_t acc4[4];
            int32_t acc5[4];
            int32_t acc6[4];
            int32_t acc7[4];
            int32_t acc8[4];
            int32_t acc9[4];
            int32_t acc10[4];
            int32_t acc11[4];
            int32_t acc12[4];
            int32_t acc13[4];
            int32_t acc14[4];
            int32_t acc15[4];

            memcpy(acc0, initValue, 4 * sizeof(int32_t));
            memcpy(acc1, initValue, 4 * sizeof(int32_t));
            memcpy(acc2, initValue, 4 * sizeof(int32_t));
            memcpy(acc3, initValue, 4 * sizeof(int32_t));
            memcpy(acc4, initValue, 4 * sizeof(int32_t));
            memcpy(acc5, initValue, 4 * sizeof(int32_t));
            memcpy(acc6, initValue, 4 * sizeof(int32_t));
            memcpy(acc7, initValue, 4 * sizeof(int32_t));
            memcpy(acc8, initValue, 4 * sizeof(int32_t));
            memcpy(acc9, initValue, 4 * sizeof(int32_t));
            memcpy(acc10, initValue, 4 * sizeof(int32_t));
            memcpy(acc11, initValue, 4 * sizeof(int32_t));
            memcpy(acc12, initValue, 4 * sizeof(int32_t));
            memcpy(acc13, initValue, 4 * sizeof(int32_t));
            memcpy(acc14, initValue, 4 * sizeof(int32_t));
            memcpy(acc15, initValue, 4 * sizeof(int32_t));

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {

                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t a4 = a[4];
                const int8_t a5 = a[5];
                const int8_t a6 = a[6];
                const int8_t a7 = a[7];
                const int8_t a8 = a[8];
                const int8_t a9 = a[9];
                const int8_t a10 = a[10];
                const int8_t a11 = a[11];
                const int8_t a12 = a[12];
                const int8_t a13 = a[13];
                const int8_t a14 = a[14];
                const int8_t a15 = a[15];

                const int8_t wv[4] = {*w++, *w++, *w++, *w++}; // braced-init-list reads are sequenced left to right

                // MNN_PRINT("16-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value[0-3]:, a value[0-15]:", ie, a - A, w - B - 4, c - C);
                // formatMatrix(wv, {4});
                // formatMatrix(a, {16});
                // MNN_PRINT("\n");
                a = a + diff;
                for (int lane = 0; lane < 4; lane++) {
                    acc0[lane] += (int32_t)a0 * (int32_t)wv[lane];
                    acc1[lane] += (int32_t)a1 * (int32_t)wv[lane];
                    acc2[lane] += (int32_t)a2 * (int32_t)wv[lane];
                    acc3[lane] += (int32_t)a3 * (int32_t)wv[lane];
                    acc4[lane] += (int32_t)a4 * (int32_t)wv[lane];
                    acc5[lane] += (int32_t)a5 * (int32_t)wv[lane];
                    acc6[lane] += (int32_t)a6 * (int32_t)wv[lane];
                    acc7[lane] += (int32_t)a7 * (int32_t)wv[lane];
                    acc8[lane] += (int32_t)a8 * (int32_t)wv[lane];
                    acc9[lane] += (int32_t)a9 * (int32_t)wv[lane];
                    acc10[lane] += (int32_t)a10 * (int32_t)wv[lane];
                    acc11[lane] += (int32_t)a11 * (int32_t)wv[lane];
                    acc12[lane] += (int32_t)a12 * (int32_t)wv[lane];
                    acc13[lane] += (int32_t)a13 * (int32_t)wv[lane];
                    acc14[lane] += (int32_t)a14 * (int32_t)wv[lane];
                    acc15[lane] += (int32_t)a15 * (int32_t)wv[lane];
                }
            }

            int8_t result0[4];
            int8_t result1[4];
            int8_t result2[4];
            int8_t result3[4];
            int8_t result4[4];
            int8_t result5[4];
            int8_t result6[4];
            int8_t result7[4];
            int8_t result8[4];
            int8_t result9[4];
            int8_t result10[4];
            int8_t result11[4];
            int8_t result12[4];
            int8_t result13[4];
            int8_t result14[4];
            int8_t result15[4];

            if (scales) {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue))));
                    result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc1[lane])), float(minValue))));
                    result2[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc2[lane])), float(minValue))));
                    result3[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc3[lane])), float(minValue))));
                    result4[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc4[lane])), float(minValue))));
                    result5[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc5[lane])), float(minValue))));
                    result6[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc6[lane])), float(minValue))));
                    result7[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc7[lane])), float(minValue))));
                    result8[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc8[lane])), float(minValue))));
                    result9[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc9[lane])), float(minValue))));
                    result10[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc10[lane])), float(minValue))));
                    result11[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc11[lane])), float(minValue))));
                    result12[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc12[lane])), float(minValue))));
                    result13[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc13[lane])), float(minValue))));
                    result14[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc14[lane])), float(minValue))));
                    result15[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc15[lane])), float(minValue))));
                }
            } else {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue)));
                    result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc1[lane]), minValue)));
                    result2[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc2[lane]), minValue)));
                    result3[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc3[lane]), minValue)));
                    result4[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc4[lane]), minValue)));
                    result5[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc5[lane]), minValue)));
                    result6[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc6[lane]), minValue)));
                    result7[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc7[lane]), minValue)));
                    result8[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc8[lane]), minValue)));
                    result9[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc9[lane]), minValue)));
                    result10[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc10[lane]), minValue)));
                    result11[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc11[lane]), minValue)));
                    result12[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc12[lane]), minValue)));
                    result13[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc13[lane]), minValue)));
                    result14[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc14[lane]), minValue)));
                    result15[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc15[lane]), minValue)));
                }
            }

            memcpy(c, result0, 4 * sizeof(int8_t)); // store contiguous c
            memcpy(c + 4, result1, 4 * sizeof(int8_t));
            memcpy(c + 4 * 2, result2, 4 * sizeof(int8_t));
            memcpy(c + 4 * 3, result3, 4 * sizeof(int8_t));
            memcpy(c + 4 * 4, result4, 4 * sizeof(int8_t));
            memcpy(c + 4 * 5, result5, 4 * sizeof(int8_t));
            memcpy(c + 4 * 6, result6, 4 * sizeof(int8_t));
            memcpy(c + 4 * 7, result7, 4 * sizeof(int8_t));
            memcpy(c + 4 * 8, result8, 4 * sizeof(int8_t));
            memcpy(c + 4 * 9, result9, 4 * sizeof(int8_t));
            memcpy(c + 4 * 10, result10, 4 * sizeof(int8_t));
            memcpy(c + 4 * 11, result11, 4 * sizeof(int8_t));
            memcpy(c + 4 * 12, result12, 4 * sizeof(int8_t));
            memcpy(c + 4 * 13, result13, 4 * sizeof(int8_t));
            memcpy(c + 4 * 14, result14, 4 * sizeof(int8_t));
            memcpy(c + 4 * 15, result15, 4 * sizeof(int8_t));
        }

        blockC += (h >> 2) * cStride;
        for (; ih < h; ih++) {
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;
            int32_t acc1 = initValue;
            int32_t acc2 = initValue;
            int32_t acc3 = initValue;
            int32_t acc4 = initValue;
            int32_t acc5 = initValue;
            int32_t acc6 = initValue;
            int32_t acc7 = initValue;
            int32_t acc8 = initValue;
            int32_t acc9 = initValue;
            int32_t acc10 = initValue;
            int32_t acc11 = initValue;
            int32_t acc12 = initValue;
            int32_t acc13 = initValue;
            int32_t acc14 = initValue;
            int32_t acc15 = initValue;
            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {

                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t a4 = a[4];
                const int8_t a5 = a[5];
                const int8_t a6 = a[6];
                const int8_t a7 = a[7];
                const int8_t a8 = a[8];
                const int8_t a9 = a[9];
                const int8_t a10 = a[10];
                const int8_t a11 = a[11];
                const int8_t a12 = a[12];
                const int8_t a13 = a[13];
                const int8_t a14 = a[14];
                const int8_t a15 = a[15];

                const int8_t oneW = *w++;

                // MNN_PRINT("16-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-15]:", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {16});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += (int32_t)a0 * (int32_t)oneW;
                acc1 += (int32_t)a1 * (int32_t)oneW;
                acc2 += (int32_t)a2 * (int32_t)oneW;
                acc3 += (int32_t)a3 * (int32_t)oneW;
                acc4 += (int32_t)a4 * (int32_t)oneW;
                acc5 += (int32_t)a5 * (int32_t)oneW;
                acc6 += (int32_t)a6 * (int32_t)oneW;
                acc7 += (int32_t)a7 * (int32_t)oneW;
                acc8 += (int32_t)a8 * (int32_t)oneW;
                acc9 += (int32_t)a9 * (int32_t)oneW;
                acc10 += (int32_t)a10 * (int32_t)oneW;
                acc11 += (int32_t)a11 * (int32_t)oneW;
                acc12 += (int32_t)a12 * (int32_t)oneW;
                acc13 += (int32_t)a13 * (int32_t)oneW;
                acc14 += (int32_t)a14 * (int32_t)oneW;
                acc15 += (int32_t)a15 * (int32_t)oneW;
            }

            int8_t result0; // in assembly code, consider reusing acc0's low 8 bits
            int8_t result1;
            int8_t result2;
            int8_t result3;
            int8_t result4;
            int8_t result5;
            int8_t result6;
            int8_t result7;
            int8_t result8;
            int8_t result9;
            int8_t result10;
            int8_t result11;
            int8_t result12;
            int8_t result13;
            int8_t result14;
            int8_t result15;

            if (scales) {
                result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
                result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue))));
                result2 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue))));
                result3 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue))));
                result4 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc4)), float(minValue))));
                result5 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc5)), float(minValue))));
                result6 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc6)), float(minValue))));
                result7 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc7)), float(minValue))));
                result8 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc8)), float(minValue))));
                result9 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc9)), float(minValue))));
                result10 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc10)), float(minValue))));
                result11 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc11)), float(minValue))));
                result12 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc12)), float(minValue))));
                result13 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc13)), float(minValue))));
                result14 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc14)), float(minValue))));
                result15 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc15)), float(minValue))));
            } else {
                result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
                result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue));
                result2 = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue));
                result3 = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue));
                result4 = static_cast<int8_t>(std::max(std::min(maxValue, acc4), minValue));
                result5 = static_cast<int8_t>(std::max(std::min(maxValue, acc5), minValue));
                result6 = static_cast<int8_t>(std::max(std::min(maxValue, acc6), minValue));
                result7 = static_cast<int8_t>(std::max(std::min(maxValue, acc7), minValue));
                result8 = static_cast<int8_t>(std::max(std::min(maxValue, acc8), minValue));
                result9 = static_cast<int8_t>(std::max(std::min(maxValue, acc9), minValue));
                result10 = static_cast<int8_t>(std::max(std::min(maxValue, acc10), minValue));
                result11 = static_cast<int8_t>(std::max(std::min(maxValue, acc11), minValue));
                result12 = static_cast<int8_t>(std::max(std::min(maxValue, acc12), minValue));
                result13 = static_cast<int8_t>(std::max(std::min(maxValue, acc13), minValue));
                result14 = static_cast<int8_t>(std::max(std::min(maxValue, acc14), minValue));
                result15 = static_cast<int8_t>(std::max(std::min(maxValue, acc15), minValue));
            }

            // how to store faster: st4 / transpose /
            c[0] = result0;
            c[4] = result1;
            c[4 * 2] = result2;
            c[4 * 3] = result3;
            c[4 * 4] = result4;
            c[4 * 5] = result5;
            c[4 * 6] = result6;
            c[4 * 7] = result7;
            c[4 * 8] = result8;
            c[4 * 9] = result9;
            c[4 * 10] = result10;
            c[4 * 11] = result11;
            c[4 * 12] = result12;
            c[4 * 13] = result13;
            c[4 * 14] = result14;
            c[4 * 15] = result15;
        }
        a += aStride;
    }
    if (eSize & 0x08) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        // a = blockA + diff;
        a += diff;
        const int8_t* w = B;
        int8_t* blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < (h & (~0x03)); ih += sparseBlockOC) {
            auto ihPack = ih >> 2;
            auto c = blockC + ihPack * cStride;
            int32_t initValue[4] = {0, 0, 0, 0};
            if (nullptr != bias) {
                memcpy(initValue, bias + ih, 4 * sizeof(int32_t));
            }
            int32_t acc0[4];
            int32_t acc1[4];
            int32_t acc2[4];
            int32_t acc3[4];
            int32_t acc4[4];
            int32_t acc5[4];
            int32_t acc6[4];
            int32_t acc7[4];

            memcpy(acc0, initValue, 4 * sizeof(int32_t));
            memcpy(acc1, initValue, 4 * sizeof(int32_t));
            memcpy(acc2, initValue, 4 * sizeof(int32_t));
            memcpy(acc3, initValue, 4 * sizeof(int32_t));
            memcpy(acc4, initValue, 4 * sizeof(int32_t));
            memcpy(acc5, initValue, 4 * sizeof(int32_t));
            memcpy(acc6, initValue, 4 * sizeof(int32_t));
            memcpy(acc7, initValue, 4 * sizeof(int32_t));

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {

                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t a4 = a[4];
                const int8_t a5 = a[5];
                const int8_t a6 = a[6];
                const int8_t a7 = a[7];
                const int8_t wv[4] = {*w++, *w++, *w++, *w++};
                // MNN_PRINT("8-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value[0-3]:, a value[0-7]:\n", ie, a - A, w - B - 1, c - C);
                // formatMatrix(wv, {4});
                // formatMatrix(a, {8});
                // MNN_PRINT("\n");
                a = a + diff;
                for (int lane = 0; lane < 4; lane++) {
                    acc0[lane] += int32_t(a0) * int32_t(wv[lane]);
                    acc1[lane] += int32_t(a1) * int32_t(wv[lane]);
                    acc2[lane] += int32_t(a2) * int32_t(wv[lane]);
                    acc3[lane] += int32_t(a3) * int32_t(wv[lane]);
                    acc4[lane] += int32_t(a4) * int32_t(wv[lane]);
                    acc5[lane] += int32_t(a5) * int32_t(wv[lane]);
                    acc6[lane] += int32_t(a6) * int32_t(wv[lane]);
                    acc7[lane] += int32_t(a7) * int32_t(wv[lane]);
                }
            }

            int8_t result0[4];
            int8_t result1[4];
            int8_t result2[4];
            int8_t result3[4];
            int8_t result4[4];
            int8_t result5[4];
            int8_t result6[4];
            int8_t result7[4];

            if (scales) {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue))));
                    result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc1[lane])), float(minValue))));
                    result2[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc2[lane])), float(minValue))));
                    result3[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc3[lane])), float(minValue))));
                    result4[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc4[lane])), float(minValue))));
                    result5[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc5[lane])), float(minValue))));
                    result6[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc6[lane])), float(minValue))));
                    result7[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc7[lane])), float(minValue))));
                }
            } else {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue)));
                    result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc1[lane]), minValue)));
                    result2[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc2[lane]), minValue)));
                    result3[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc3[lane]), minValue)));
                    result4[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc4[lane]), minValue)));
                    result5[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc5[lane]), minValue)));
                    result6[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc6[lane]), minValue)));
                    result7[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc7[lane]), minValue)));
                }
            }

            memcpy(c, result0, 4 * sizeof(int8_t)); // store contiguous c
            memcpy(c + 4, result1, 4 * sizeof(int8_t));
            memcpy(c + 4 * 2, result2, 4 * sizeof(int8_t));
            memcpy(c + 4 * 3, result3, 4 * sizeof(int8_t));
            memcpy(c + 4 * 4, result4, 4 * sizeof(int8_t));
            memcpy(c + 4 * 5, result5, 4 * sizeof(int8_t));
            memcpy(c + 4 * 6, result6, 4 * sizeof(int8_t));
            memcpy(c + 4 * 7, result7, 4 * sizeof(int8_t));
        }
        blockC += (ih >> 2) * cStride;
        for (; ih < h; ih++) {
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;
            int32_t acc1 = initValue;
            int32_t acc2 = initValue;
            int32_t acc3 = initValue;
            int32_t acc4 = initValue;
            int32_t acc5 = initValue;
            int32_t acc6 = initValue;
            int32_t acc7 = initValue;

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {
                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t a4 = a[4];
                const int8_t a5 = a[5];
                const int8_t a6 = a[6];
                const int8_t a7 = a[7];
                const int8_t oneW = *w++;
                // MNN_PRINT("8-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-7]:\n", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {8});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += int32_t(a0) * int32_t(oneW);
                acc1 += int32_t(a1) * int32_t(oneW);
                acc2 += int32_t(a2) * int32_t(oneW);
                acc3 += int32_t(a3) * int32_t(oneW);
                acc4 += int32_t(a4) * int32_t(oneW);
                acc5 += int32_t(a5) * int32_t(oneW);
                acc6 += int32_t(a6) * int32_t(oneW);
                acc7 += int32_t(a7) * int32_t(oneW);
            }

            int8_t result0;
            int8_t result1;
            int8_t result2;
            int8_t result3;
            int8_t result4;
            int8_t result5;
            int8_t result6;
            int8_t result7;
            if (scales) {
                result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
                result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue))));
                result2 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue))));
                result3 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue))));
                result4 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc4)), float(minValue))));
                result5 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc5)), float(minValue))));
                result6 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc6)), float(minValue))));
                result7 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc7)), float(minValue))));
            } else {
                result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
                result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue));
                result2 = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue));
                result3 = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue));
                result4 = static_cast<int8_t>(std::max(std::min(maxValue, acc4), minValue));
                result5 = static_cast<int8_t>(std::max(std::min(maxValue, acc5), minValue));
                result6 = static_cast<int8_t>(std::max(std::min(maxValue, acc6), minValue));
                result7 = static_cast<int8_t>(std::max(std::min(maxValue, acc7), minValue));
            }

            // how to store faster: st4 / transpose /
            c[0] = result0;
            c[4] = result1;
            c[4 * 2] = result2;
            c[4 * 3] = result3;
            c[4 * 4] = result4;
            c[4 * 5] = result5;
            c[4 * 6] = result6;
            c[4 * 7] = result7;
        }
        ie += 8;
        a += 8;
    }
    if (eSize & 0x04) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        // a = blockA + diff;
        a += diff;
        const int8_t* w = B;
        int8_t* blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < (h & (~0x03)); ih += sparseBlockOC) {
            auto ihPack = ih >> 2;
            auto c = blockC + ihPack * cStride;
            int32_t initValue[4] = {0, 0, 0, 0};
            if (nullptr != bias) {
                memcpy(initValue, bias + ih, 4 * sizeof(int32_t));
            }
            int32_t acc0[4];
            int32_t acc1[4];
            int32_t acc2[4];
            int32_t acc3[4];

            memcpy(acc0, initValue, 4 * sizeof(int32_t));
            memcpy(acc1, initValue, 4 * sizeof(int32_t));
            memcpy(acc2, initValue, 4 * sizeof(int32_t));
            memcpy(acc3, initValue, 4 * sizeof(int32_t));

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {

                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t wv[4] = {*w++, *w++, *w++, *w++};
                // MNN_PRINT("4-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value[0-3]:, a value[0-3]:\n", ie, a - A, w - B - 1, c - C);
                // formatMatrix(wv, {4});
                // formatMatrix(a, {4});
                // MNN_PRINT("\n");
                a = a + diff;
                for (int lane = 0; lane < 4; lane++) {
                    acc0[lane] += int32_t(a0) * int32_t(wv[lane]);
                    acc1[lane] += int32_t(a1) * int32_t(wv[lane]);
                    acc2[lane] += int32_t(a2) * int32_t(wv[lane]);
                    acc3[lane] += int32_t(a3) * int32_t(wv[lane]);
                }
            }

            int8_t result0[4];
            int8_t result1[4];
            int8_t result2[4];
            int8_t result3[4];

            if (scales) {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue))));
                    result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc1[lane])), float(minValue))));
                    result2[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc2[lane])), float(minValue))));
                    result3[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc3[lane])), float(minValue))));
                }
            } else {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue)));
                    result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc1[lane]), minValue)));
                    result2[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc2[lane]), minValue)));
                    result3[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc3[lane]), minValue)));
                }
            }

            memcpy(c, result0, 4 * sizeof(int8_t)); // store contiguous c
            memcpy(c + 4, result1, 4 * sizeof(int8_t));
            memcpy(c + 4 * 2, result2, 4 * sizeof(int8_t));
            memcpy(c + 4 * 3, result3, 4 * sizeof(int8_t));
        }
        blockC += (ih >> 2) * cStride;
        for (; ih < h; ih++) {
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;
            int32_t acc1 = initValue;
            int32_t acc2 = initValue;
            int32_t acc3 = initValue;

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {
                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t oneW = *w++;
                // MNN_PRINT("4-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-3]:\n", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {4});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += int32_t(a0) * int32_t(oneW);
                acc1 += int32_t(a1) * int32_t(oneW);
                acc2 += int32_t(a2) * int32_t(oneW);
                acc3 += int32_t(a3) * int32_t(oneW);
            }

            int8_t result0;
            int8_t result1;
            int8_t result2;
            int8_t result3;
            if (scales) {
                result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
                result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue))));
                result2 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue))));
                result3 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue))));
            } else {
                result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
                result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue));
                result2 = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue));
                result3 = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue));
            }

            // how to store faster: st4 / transpose /
            c[0] = result0;
            c[4] = result1;
            c[4 * 2] = result2;
            c[4 * 3] = result3;
        }
        ie += 4;
        a += 4;
    }
    if (eSize & 0x02) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        // a = blockA + diff;
        a += diff;
        const int8_t* w = B;
        int8_t* blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < (h & (~0x03)); ih += sparseBlockOC) {
            auto ihPack = ih >> 2;
            auto c = blockC + ihPack * cStride;
            int32_t initValue[4] = {0, 0, 0, 0};
            if (nullptr != bias) {
                memcpy(initValue, bias + ih, 4 * sizeof(int32_t));
            }
            int32_t acc0[4];
            int32_t acc1[4];
            memcpy(acc0, initValue, 4 * sizeof(int32_t));
            memcpy(acc1, initValue, 4 * sizeof(int32_t));

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {

                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t wv[4] = {*w++, *w++, *w++, *w++};
                // MNN_PRINT("2-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value[0-3]:, a value[0-1]:\n", ie, a - A, w - B - 1, c - C);
                // formatMatrix(wv, {4});
                // formatMatrix(a, {2});
                // MNN_PRINT("\n");
                a = a + diff;
                for (int lane = 0; lane < 4; lane++) {
                    acc0[lane] += int32_t(a0) * int32_t(wv[lane]);
                    acc1[lane] += int32_t(a1) * int32_t(wv[lane]);
                }
            }

            int8_t result0[4];
            int8_t result1[4];
            if (scales) {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue))));
                    result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc1[lane])), float(minValue))));
                }
            } else {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue)));
                    result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc1[lane]), minValue)));
                }
            }

            memcpy(c, result0, 4 * sizeof(int8_t)); // store contiguous c
            memcpy(c + 4, result1, 4 * sizeof(int8_t));
        }
        blockC += (ih >> 2) * cStride;
        for (; ih < h; ih++) {
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;
            int32_t acc1 = initValue;

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {
                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t oneW = *w++;
                // MNN_PRINT("2-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-1]:\n", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {2});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += int32_t(a0) * int32_t(oneW);
                acc1 += int32_t(a1) * int32_t(oneW);
            }

            int8_t result0;
            int8_t result1;
            if (scales) {
                result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
                result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue))));
            } else {
                result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
                result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue));
            }

            // how to store faster: st4 / transpose /
            c[0] = result0;
            c[4] = result1;
        }
        ie += 2;
        a += 2;
    }
    if (eSize & 0x01) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        // const float* a = blockA + diff;
        a += diff;
        const int8_t* w = B;
        int8_t* blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < (h & (~0x03)); ih += sparseBlockOC) {
            auto ihPack = ih >> 2;
            auto c = blockC + ihPack * cStride;
            int32_t initValue[4] = {0, 0, 0, 0};
            if (nullptr != bias) {
                memcpy(initValue, bias + ih, 4 * sizeof(int32_t));
            }
            int32_t acc0[4];
            memcpy(acc0, initValue, 4 * sizeof(int32_t));
            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {

                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t wv[4] = {*w++, *w++, *w++, *w++};
                // MNN_PRINT("1-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value[0-3]:, a value[0]:\n", ie, a - A, w - B - 1, c - C);
                // formatMatrix(wv, {4});
                // formatMatrix(a, {1});
                // MNN_PRINT("\n");
                a = a + diff;
                for (int lane = 0; lane < 4; lane++) {
                    acc0[lane] += int32_t(a0) * int32_t(wv[lane]);
                }
            }

            int8_t result0[4];
            if (scales) {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue))));
                }
            } else {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue)));
                }
            }
            memcpy(c, result0, 4 * sizeof(int8_t)); // store contiguous c
        }
        blockC += (ih >> 2) * cStride;
        for (; ih < h; ih++) {
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {
                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t oneW = *w++;

                // MNN_PRINT("1-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0]:\n", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {1});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += int32_t(a0) * int32_t(oneW);
            }
            int8_t result0;
            if (scales) {
                result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
            } else {
                result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
            }
            // how to store faster: st4 / transpose /
            c[0] = result0;
        }
        ie += 1;
        // a += 1;
    }

}
|
|
|
|
|
|
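/* Reading guide for the sparse kernels above (illustrative, matching the packing
   in MNNPackForSparseQuantMatMul_B): for each output-channel block, *nnz++ gives
   the number of retained weight blocks in that row, and *dataOffset++ gives the
   increment (pre-scaled by eP at pack time) that moves the packed activation
   pointer `a` to the column of the next retained block. Requantization is
   q = clamp(round(scale * acc), minValue, maxValue), or a plain clamp when
   `scales` is null. */
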
static int8_t MNNInt32ToInt8(int data, int bias, float scale, float maxValue, float minValue) {
    float value = (float)(data + bias) * scale;
    value = ALIMAX(value, minValue);
    value = ALIMIN(value, maxValue);
    return static_cast<int8_t>(roundf(value));
}

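/* Worked example for MNNInt32ToInt8: data = 1000, bias = 24, scale = 0.05f and
   bounds [-127, 127] give value = (1000 + 24) * 0.05 = 51.2, which is clamped
   (a no-op here) and rounded to 51. Saturation happens before rounding. */
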
static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step,
                                              size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) {
    const int bytes = ((post->useInt8 == 1) ? 1 : 4);
    for (int dz = 0; dz < dst_depth_quad; ++dz) {
        const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
        const auto bias_dz = post->bias + dz * GEMM_INT8_UNIT;
        const float* scale_dz = post->scale + dz * GEMM_INT8_UNIT;
        auto dst_z = dst + dz * dst_step;
        for (int w = 0; w < realCount; ++w) {
            const auto src_x = src + w * GEMM_INT8_SRC_UNIT;
            auto dst_x = dst_z + w * GEMM_INT8_UNIT * bytes;
            int32_t dstTemp[4] = {0, 0, 0, 0};

            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sz;
                const auto src_z = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT;

                for (int j = 0; j < GEMM_INT8_UNIT; ++j) {
                    const auto weight_j = weight_sz + j * GEMM_INT8_SRC_UNIT;
                    for (int i = 0; i < GEMM_INT8_SRC_UNIT; ++i) {
                        dstTemp[j] += (int32_t)src_z[i] * (int32_t)weight_j[i];
                    }
                }
            }

            for (int j = 0; j < GEMM_INT8_UNIT; ++j) {
                if (!post->scale) {
                    ((float*)dst_x)[j] = (float)(dstTemp[j] + bias_dz[j]);
                } else if (post->useInt8 == 1) {
                    dst_x[j] = MNNInt32ToInt8(dstTemp[j], bias_dz[j], scale_dz[j], post->maxValue, post->minValue);
                } else {
                    float value = (float)(dstTemp[j] + bias_dz[j]) * scale_dz[j];
                    ((float*)dst_x)[j] = value;
                }
            }
        }
    }
}

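/* Layout note for the reference kernel above: weights are packed as
   [dst_depth_quad][src_depth_quad][GEMM_INT8_UNIT][GEMM_INT8_SRC_UNIT] and the
   im2col input as [src_depth_quad][GEMM_INT8_DST_XUNIT][GEMM_INT8_SRC_UNIT],
   so each output value is a GEMM_INT8_SRC_UNIT-long dot product accumulated
   over src_depth_quad. The assembly kernels declared under MNN_USE_NEON are
   expected to produce the same results as this C fallback. */
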
static void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) {
    return MNNGemmInt8AddBiasScale_16x4_Unit(dst, src, weight, src_depth_quad, dst_step, dst_depth_quad, post, realCount);
}

static void MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters,
                                                 size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step,
                                                 size_t dilateY_step) {
    auto bias_z = parameters->bias;
    auto scale_z = parameters->scale;
    int dx, fx, fy;
    for (dx = 0; dx < width; ++dx) {
        auto dst_x = dst + dx * 4;
        int32_t dstInt32[4] = {0, 0, 0, 0};
        const auto src_z = src + src_w_step * dx;
        for (fy = 0; fy < fh; ++fy) {
            const auto src_y = src_z + fy * dilateY_step;
            const auto weight_y = weight + fy * fw * 4;
            for (fx = 0; fx < fw; ++fx) {
                const auto src_x = src_y + fx * dilateX_step;
                const auto weight_x = weight_y + 4 * fx;
                for (int j = 0; j < GEMM_INT8_UNIT; ++j) {
                    dstInt32[j] += (int32_t)src_x[j] * (int32_t)weight_x[j];
                }
            }
        }

        for (int i = 0; i < GEMM_INT8_UNIT; ++i) {
            dst_x[i] = MNNInt32ToInt8(dstInt32[i], bias_z[i], scale_z[i], parameters->maxValue, parameters->minValue);
        }
    }
}

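/* Example for the depthwise line kernel above: with fw = fh = 3, each of the
   `width` output pixels accumulates 3 x 3 = 9 products per channel of its
   4-channel group (the source stepped by dilateX_step / dilateY_step), then
   requantizes through MNNInt32ToInt8. */
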
#endif

#ifndef MNN_USE_NEON
void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue,
                   ssize_t maxValue, ssize_t zeroPoint) {
    for (int i = 0; i < sizeQuad; ++i) {
        for (int j = 0; j < 4; ++j) {
            int v = (int)roundf(src[4 * i + j] * scalep[j]) + zeroPoint;
            if (v > maxValue) {
                v = maxValue;
            }
            if (v < minValue) {
                v = minValue;
            }
            dst[4 * i + j] = v;
        }
    }
}

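/* Example: with scalep[j] = 63.5f (mapping the float range [-2, 2] onto
   [-127, 127]) and zeroPoint = 0, an input of 0.5f becomes
   round(0.5 * 63.5) = 32; results outside [minValue, maxValue] are saturated. */
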
void MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t size, ssize_t zeroPoint) {
    for (int i = 0; i < size; ++i) {
        const auto srcStart = src + i * 4;
        auto dstStart = dst + i * 4;
        for (int j = 0; j < 4; ++j) {
            dstStart[j] = static_cast<float>(srcStart[j] - zeroPoint) * scale[j];
        }
    }
}

void MNNAvgPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx, ssize_t paddingx, ssize_t factor) {
    int pack = 16;
    int8_t* dstPtr = dst;
    const int8_t* srcPtr = src;
    for (int ox = 0; ox < outputWidth; ++ox) {
        std::vector<int> sum_(pack, 0);
        for (int y = 0; y < kernely; ++y) {
            for (int x = 0; x < kernelx; ++x) {
                const int8_t* inputPtr = srcPtr + pack * (x + inputWidth * y);
                for (int idx = 0; idx < pack; ++idx) {
                    sum_[idx] += *(inputPtr + idx);
                }
            }
        }
        for (int idx = 0; idx < pack; ++idx) {
            *(dstPtr + idx) = static_cast<int8_t>((sum_[idx] * factor) >> 24);
        }
        dstPtr = dstPtr + pack;
        srcPtr = srcPtr + pack * stridesx;
    }
}

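/* The caller is expected to pass `factor` as a Q24 fixed-point reciprocal of
   the kernel area, so (sum * factor) >> 24 approximates the mean. Example: for
   a 3x3 window, factor = (1 << 24) / 9 = 1864135, and a channel sum of 90
   yields (90 * 1864135) >> 24 = 9, i.e. roughly 90 / 9 with truncation. */
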
void MNNMaxPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx) {
    int pack = 16;
    int8_t* dstPtr = dst;
    const int8_t* srcPtr = src;
    for (int ox = 0; ox < outputWidth; ++ox) {
        std::vector<int8_t> results(pack, INT8_MIN);
        for (int y = 0; y < kernely; ++y) {
            for (int x = 0; x < kernelx; ++x) {
                const int8_t* inputPtr = srcPtr + pack * (x + inputWidth * y);
                for (int idx = 0; idx < pack; ++idx) {
                    results[idx] = std::max(results[idx], *(inputPtr + idx));
                }
            }
        }

        for (int idx = 0; idx < pack; ++idx) {
            *(dstPtr + idx) = results[idx];
        }
        dstPtr = dstPtr + pack;
        srcPtr = srcPtr + pack * stridesx;
    }
}

void MNNBinaryAddInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast) {
    float sum = 0;
#ifdef MNN_USE_SSE
    const int zeroPoint = 128;
    const int maxValue = 255;
    const int minValue = 0;
    const uint8_t* inputData0 = (uint8_t*)inputRaw0;
    const uint8_t* inputData1 = (uint8_t*)inputRaw1;
    uint8_t* outputData = (uint8_t*)outputRaw;
#else
    const int zeroPoint = 0;
    const int maxValue = 127;
    const int minValue = -128;
    const int8_t* inputData0 = inputRaw0;
    const int8_t* inputData1 = inputRaw1;
    int8_t* outputData = outputRaw;
#endif
    for (int i = 0; i < elementSize; ++i) {
        if (needBroadcast == 0) {
            float inp0 = (inputData0[0] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[i] - zeroPoint) * inputScale1[i];
            sum = inp0 + inp1;
        } else if (needBroadcast == 1) {
            float inp0 = (inputData0[i] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[0] - zeroPoint) * inputScale1[i];
            sum = inp0 + inp1;
        } else {
            float inp0 = (inputData0[i] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[i] - zeroPoint) * inputScale1[i];
            sum = inp0 + inp1;
        }
        int value = (int)roundf(sum * outputScale[i]) + zeroPoint;
        if (value > maxValue) {
            value = maxValue;
        }
        if (value < minValue) {
            value = minValue;
        }
        outputData[i] = value;
    }
}

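/* Note on the MNN_USE_SSE branches in these binary ops: the x86 int8 path works
   on uint8 data with zeroPoint = 128 (the int8 value shifted by +128), so the
   clamp bounds become [0, 255]; after subtracting the zero point the float
   arithmetic is identical to the signed path. */
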
void MNNBinarySubInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast) {
    float res = 0;
#ifdef MNN_USE_SSE
    const int zeroPoint = 128;
    const int maxValue = 255;
    const int minValue = 0;
    const uint8_t* inputData0 = (uint8_t*)inputRaw0;
    const uint8_t* inputData1 = (uint8_t*)inputRaw1;
    uint8_t* outputData = (uint8_t*)outputRaw;
#else
    const int zeroPoint = 0;
    const int maxValue = 127;
    const int minValue = -128;
    const int8_t* inputData0 = inputRaw0;
    const int8_t* inputData1 = inputRaw1;
    int8_t* outputData = outputRaw;
#endif
    for (int i = 0; i < elementSize; ++i) {
        if (needBroadcast == 0) {
            float inp0 = (inputData0[0] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[i] - zeroPoint) * inputScale1[i];
            res = inp0 - inp1;
        } else if (needBroadcast == 1) {
            float inp0 = (inputData0[i] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[0] - zeroPoint) * inputScale1[i];
            res = inp0 - inp1;
        } else {
            float inp0 = (inputData0[i] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[i] - zeroPoint) * inputScale1[i];
            res = inp0 - inp1;
        }
        int value = (int)roundf(res * outputScale[i]) + zeroPoint;
        if (value > maxValue) {
            value = maxValue;
        }
        if (value < minValue) {
            value = minValue;
        }
        outputData[i] = value;
    }
}

void MNNBinaryMulInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast) {
    float res = 0;
#ifdef MNN_USE_SSE
    const int zeroPoint = 128;
    const int maxValue = 255;
    const int minValue = 0;
    const uint8_t* inputData0 = (uint8_t*)inputRaw0;
    const uint8_t* inputData1 = (uint8_t*)inputRaw1;
    uint8_t* outputData = (uint8_t*)outputRaw;
#else
    const int zeroPoint = 0;
    const int maxValue = 127;
    const int minValue = -128;
    const int8_t* inputData0 = inputRaw0;
    const int8_t* inputData1 = inputRaw1;
    int8_t* outputData = outputRaw;
#endif
    for (int i = 0; i < elementSize; ++i) {
        if (needBroadcast == 0) {
            float inp0 = (inputData0[0] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[i] - zeroPoint) * inputScale1[i];
            res = inp0 * inp1;
        } else if (needBroadcast == 1) {
            float inp0 = (inputData0[i] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[0] - zeroPoint) * inputScale1[i];
            res = inp0 * inp1;
        } else {
            float inp0 = (inputData0[i] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[i] - zeroPoint) * inputScale1[i];
            res = inp0 * inp1;
        }
        int value = (int)roundf(res * outputScale[i]) + zeroPoint;
        if (value > maxValue) {
            value = maxValue;
        }
        if (value < minValue) {
            value = minValue;
        }
        outputData[i] = value;
    }
}

void MNNBinaryMinInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast) {
    float res = 0;
#ifdef MNN_USE_SSE
    const int zeroPoint = 128;
    const int maxValue = 255;
    const int minValue = 0;
    const uint8_t* inputData0 = (uint8_t*)inputRaw0;
    const uint8_t* inputData1 = (uint8_t*)inputRaw1;
    uint8_t* outputData = (uint8_t*)outputRaw;
#else
    const int zeroPoint = 0;
    const int maxValue = 127;
    const int minValue = -128;
    const int8_t* inputData0 = inputRaw0;
    const int8_t* inputData1 = inputRaw1;
    int8_t* outputData = outputRaw;
#endif
    for (int i = 0; i < elementSize; ++i) {
        if (needBroadcast == 0) {
            float inp0 = (inputData0[0] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[i] - zeroPoint) * inputScale1[i];
            res = std::min(inp0, inp1);
        } else if (needBroadcast == 1) {
            float inp0 = (inputData0[i] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[0] - zeroPoint) * inputScale1[i];
            res = std::min(inp0, inp1);
        } else {
            float inp0 = (inputData0[i] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[i] - zeroPoint) * inputScale1[i];
            res = std::min(inp0, inp1);
        }
        int value = (int)roundf(res * outputScale[i]) + zeroPoint;
        if (value > maxValue) {
            value = maxValue;
        }
        if (value < minValue) {
            value = minValue;
        }
        outputData[i] = value;
    }
}

void MNNBinaryMaxInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast) {
    float res = 0;
#ifdef MNN_USE_SSE
    const int zeroPoint = 128;
    const int maxValue = 255;
    const int minValue = 0;
    const uint8_t* inputData0 = (uint8_t*)inputRaw0;
    const uint8_t* inputData1 = (uint8_t*)inputRaw1;
    uint8_t* outputData = (uint8_t*)outputRaw;
#else
    const int zeroPoint = 0;
    const int maxValue = 127;
    const int minValue = -128;
    const int8_t* inputData0 = inputRaw0;
    const int8_t* inputData1 = inputRaw1;
    int8_t* outputData = outputRaw;
#endif
    for (int i = 0; i < elementSize; ++i) {
        if (needBroadcast == 0) {
            float inp0 = (inputData0[0] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[i] - zeroPoint) * inputScale1[i];
            res = std::max(inp0, inp1);
        } else if (needBroadcast == 1) {
            float inp0 = (inputData0[i] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[0] - zeroPoint) * inputScale1[i];
            res = std::max(inp0, inp1);
        } else {
            float inp0 = (inputData0[i] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[i] - zeroPoint) * inputScale1[i];
            res = std::max(inp0, inp1);
        }
        int value = (int)roundf(res * outputScale[i]) + zeroPoint;
        if (value > maxValue) {
            value = maxValue;
        }
        if (value < minValue) {
            value = minValue;
        }
        outputData[i] = value;
    }
}

void MNNBinarySqdInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast) {
    float res = 0;
#ifdef MNN_USE_SSE
    const int zeroPoint = 128;
    const int maxValue = 255;
    const int minValue = 0;
    const uint8_t* inputData0 = (uint8_t*)inputRaw0;
    const uint8_t* inputData1 = (uint8_t*)inputRaw1;
    uint8_t* outputData = (uint8_t*)outputRaw;
#else
    const int zeroPoint = 0;
    const int maxValue = 127;
    const int minValue = -128;
    const int8_t* inputData0 = inputRaw0;
    const int8_t* inputData1 = inputRaw1;
    int8_t* outputData = outputRaw;
#endif
    for (int i = 0; i < elementSize; ++i) {
        if (needBroadcast == 0) {
            float inp0 = (inputData0[0] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[i] - zeroPoint) * inputScale1[i];
            res = (inp0 - inp1) * (inp0 - inp1);
        } else if (needBroadcast == 1) {
            float inp0 = (inputData0[i] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[0] - zeroPoint) * inputScale1[i];
            res = (inp0 - inp1) * (inp0 - inp1);
        } else {
            float inp0 = (inputData0[i] - zeroPoint) * inputScale0[i];
            float inp1 = (inputData1[i] - zeroPoint) * inputScale1[i];
            res = (inp0 - inp1) * (inp0 - inp1);
        }
        int value = (int)roundf(res * outputScale[i]) + zeroPoint;
        if (value > maxValue) {
            value = maxValue;
        }
        if (value < minValue) {
            value = minValue;
        }
        outputData[i] = value;
    }
}

#endif // #ifndef MNN_USE_NEON
#ifndef MNN_USE_SSE

void MNNInt8FunctionInit() {
    // do nothing
}
#endif // #ifndef MNN_USE_SSE

/* CPU without sdot */
// Assume GEMM_INT8_UNIT == 4 && GEMM_INT8_SRC_UNIT == 16
static void _fastIm2Col(int8_t* colAddr, const int8_t* inputOrigin, int32_t inputZeroPoint,
                        const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
                        size_t realDstCount) {
    const int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_SRC_UNIT * GEMM_INT8_DST_XUNIT * sizeof(int8_t);
    ::memset(colAddr, inputZeroPoint, col_buffer_size); // pad with the input zero point; a single value suffices since per-channel zero points were removed

    const int icDiv8 = im2colParameter->icDiv4 / 2;
    const int srcZStep = im2colParameter->iw * im2colParameter->ih * GEMM_INT8_UNIT;
    inputOrigin += xIndexStart * GEMM_INT8_UNIT;
    for (int i = 0; i < realDstCount; ++i) {
        auto colAddrI = colAddr + GEMM_INT8_SRC_UNIT * i;
        auto inputK = inputOrigin + GEMM_INT8_UNIT * i;
        for (int sz = 0; sz < icDiv8; ++sz) {
            auto inputZ0 = inputK + srcZStep * (2 * sz + 0);
            auto inputZ1 = inputK + srcZStep * (2 * sz + 1);
            const int indexOutside = sz / 2;
            const int indexInside = sz % 2;

            auto dstK0 = colAddrI + (indexOutside * GEMM_INT8_DST_XUNIT * 2 + indexInside) * (2 * GEMM_INT8_UNIT);
            auto dstK1 = dstK0 + GEMM_INT8_UNIT;
            *((int32_t*)dstK0) = *((int32_t*)inputZ0);
            *((int32_t*)dstK1) = *((int32_t*)inputZ1);
        }
    }
}

static void _im2colCommonZ1(int8_t* colAddr, const int8_t* inputOrigin, int32_t inputZeroPoint,
                            const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
                            size_t realDstCount) {
    const int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
    ::memset(colAddr, inputZeroPoint, col_buffer_size); // pad with the input zero point; a single value suffices since per-channel zero points were removed

    auto ih = im2colParameter->ih;
    auto iw = im2colParameter->iw;
    auto kh = im2colParameter->kernelY;
    auto kw = im2colParameter->kernelX;
    auto dilateX = im2colParameter->dilateX;
    auto dilateY = im2colParameter->dilateY;
    auto srcYStep = im2colParameter->srcYStep;
    constexpr int dstXStepInt32 = GEMM_INT8_SRC_UNIT * GEMM_INT8_DST_XUNIT / sizeof(int32_t);
    for (int i = 0; i < realDstCount; ++i) {
        int xIndex = (int)xIndexStart + i;
        int ox = xIndex % im2colParameter->ow;
        int oy = xIndex / im2colParameter->ow;

        int sx = ox * im2colParameter->strideX - im2colParameter->padX;
        int sy = oy * im2colParameter->strideY - im2colParameter->padY;

        int sfy = ALIMAX(0, (UP_DIV(-sy, im2colParameter->dilateY)));
        int efy = ALIMIN(kh, UP_DIV(ih - sy, im2colParameter->dilateY));
        int sfx = ALIMAX(0, (UP_DIV(-sx, im2colParameter->dilateX)));
        int efx = ALIMIN(kw, UP_DIV(iw - sx, im2colParameter->dilateX));
        int fyC = efy - sfy;
        int fxC = efx - sfx;

        auto colAddrI = colAddr + GEMM_INT8_SRC_UNIT * i;

        auto inputOffset = inputOrigin + (sy + sfy * dilateY) * srcYStep + (sx + sfx * dilateX) * GEMM_INT8_UNIT;
        auto indexOffset = sfy * kw + sfx;
        for (int fy = 0; fy < fyC; ++fy) {
            for (int fx = 0; fx < fxC; ++fx) {
                auto inputK = inputOffset + fy * dilateY * srcYStep + fx * dilateX * GEMM_INT8_UNIT;
                auto indexStart = indexOffset + fy * kw + fx;
                auto indexInside = indexStart % 4;
                auto indexOutside = indexStart / 4;
                auto dstK0 = (int32_t*)colAddrI + indexOutside * dstXStepInt32 + indexInside;
                dstK0[0] = *((int32_t*)inputK);
            }
        }
    }
}

static void _im2colCommon(int8_t* colAddr, const int8_t* inputOrigin, int32_t inputZeroPoint,
                          const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
                          size_t realDstCount) {
    const int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
    ::memset(colAddr, inputZeroPoint, col_buffer_size); // pad with the input zero point; a single value suffices since per-channel zero points were removed

    auto ih = im2colParameter->ih;
    auto iw = im2colParameter->iw;
    auto kh = im2colParameter->kernelY;
    auto kw = im2colParameter->kernelX;
    auto dilateX = im2colParameter->dilateX;
    auto dilateY = im2colParameter->dilateY;
    auto icDiv4 = im2colParameter->icDiv4;
    auto srcZStep = im2colParameter->srcZStep;
    auto srcYStep = im2colParameter->srcYStep;
    constexpr int dstXStepInt32 = GEMM_INT8_SRC_UNIT * GEMM_INT8_DST_XUNIT / sizeof(int32_t);
    for (int i = 0; i < realDstCount; ++i) {
        int xIndex = (int)xIndexStart + i;
        int ox = xIndex % im2colParameter->ow;
        int oy = xIndex / im2colParameter->ow;

        int sx = ox * im2colParameter->strideX - im2colParameter->padX;
        int sy = oy * im2colParameter->strideY - im2colParameter->padY;

        int sfy = ALIMAX(0, (UP_DIV(-sy, im2colParameter->dilateY)));
        int efy = ALIMIN(kh, UP_DIV(ih - sy, im2colParameter->dilateY));
        int sfx = ALIMAX(0, (UP_DIV(-sx, im2colParameter->dilateX)));
        int efx = ALIMIN(kw, UP_DIV(iw - sx, im2colParameter->dilateX));
        int fyC = efy - sfy;
        int fxC = efx - sfx;

        auto colAddrI = colAddr + GEMM_INT8_SRC_UNIT * i;

        auto inputOffset = inputOrigin + (sy + sfy * dilateY) * srcYStep + (sx + sfx * dilateX) * GEMM_INT8_UNIT;
        auto indexOffset = (sfy * kw + sfx) * icDiv4;
        for (int fy = 0; fy < fyC; ++fy) {
            for (int fx = 0; fx < fxC; ++fx) {
                auto inputK = inputOffset + fy * dilateY * srcYStep + fx * dilateX * GEMM_INT8_UNIT;
                auto indexStart = indexOffset + (fy * kw + fx) * icDiv4;
                for (int sz = 0; sz < icDiv4; ++sz) {
                    const int yIndex = indexStart + sz;
                    const int ySubOutside = yIndex / GEMM_INT8_UNIT;
                    const int ySubInside = yIndex % GEMM_INT8_UNIT;
                    auto dstK0 = (int32_t*)colAddrI + ySubOutside * dstXStepInt32 + ySubInside;
                    dstK0[0] = *((int32_t*)inputK);
                    inputK += srcZStep;
                }
            }
        }
    }
}

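/* Index-packing example for _im2colCommon: with icDiv4 = 3, kw = 3 and kernel
   position (fy, fx) = (0, 1), indexStart = 1 * 3 = 3, so the three 4-channel
   slices land at yIndex = 3, 4, 5, i.e. (ySubOutside, ySubInside) = (0, 3),
   (1, 0), (1, 1) within consecutive GEMM_INT8_SRC_UNIT-wide stripes. */
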
static MNN::CoreInt8Functions::Im2ColFunc chooseIm2Col(const MNN::ConvolutionCommon::Im2ColParameter* im2colParam, size_t inputChannel) {
    bool fastIm2Col = im2colParam->kernelX == 1 && im2colParam->kernelY == 1 && im2colParam->icDiv4 % 2 == 0 &&
                      im2colParam->strideX == 1 && im2colParam->strideY == 1 && im2colParam->padX == 0 &&
                      im2colParam->padY == 0;
    int ih = im2colParam->ih, iw = im2colParam->iw;
    fastIm2Col &= (im2colParam->srcYStep == iw * GEMM_INT8_UNIT && im2colParam->srcZStep == ih * iw * GEMM_INT8_UNIT);
    if (fastIm2Col) {
        return _fastIm2Col;
    } else if (inputChannel <= 4) {
        return _im2colCommonZ1;
    } else {
        return _im2colCommon;
    }
}

static void MNNGetGemmUnit(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) {
    *UNIT = GEMM_INT8_UNIT;
    *SRC_UNIT = GEMM_INT8_SRC_UNIT;
    *DST_XUNIT = GEMM_INT8_DST_XUNIT;
}
#undef GEMM_INT8_UNIT
#undef GEMM_INT8_SRC_UNIT
#undef GEMM_INT8_DST_XUNIT
/* End */

/* CPU with sdot */
#define GEMM_INT8_UNIT 4
#define GEMM_INT8_SRC_UNIT 4

#ifdef __aarch64__
#define GEMM_INT8_DST_XUNIT 12
#else
#define GEMM_INT8_DST_XUNIT 8
#endif

static void _im2colCommonSdot(int8_t* colAddr, const int8_t* src, int32_t inputZeroPoint,
                              const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
                              size_t realDstCount) {
    const int colBufferSize = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
    memset(colAddr, inputZeroPoint, colBufferSize);
    auto ih = im2colParameter->ih;
    auto iw = im2colParameter->iw;
    // auto oh = im2colParameter->oh;
    auto ow = im2colParameter->ow;
    auto kh = im2colParameter->kernelY;
    auto kw = im2colParameter->kernelX;
    auto dilateX = im2colParameter->dilateX;
    auto dilateY = im2colParameter->dilateY;
    auto icDiv4 = im2colParameter->icDiv4;
    auto srcChannelStride = im2colParameter->srcZStep;
    auto srcYStep = im2colParameter->srcYStep;
    constexpr int dstXStepInt32 = GEMM_INT8_UNIT * GEMM_INT8_DST_XUNIT / sizeof(int32_t);

    for (int i = 0; i < realDstCount; ++i) {
        int xIndex = (int)xIndexStart + i;
        int ox = xIndex % ow;
        int oy = xIndex / ow;
        int sx = ox * im2colParameter->strideX - im2colParameter->padX;
        int sy = oy * im2colParameter->strideY - im2colParameter->padY;
        int sfy = ALIMAX(0, (UP_DIV(-sy, im2colParameter->dilateY)));
        int efy = ALIMIN(kh, UP_DIV(ih - sy, im2colParameter->dilateY));
        int sfx = ALIMAX(0, (UP_DIV(-sx, im2colParameter->dilateX)));
        int efx = ALIMIN(kw, UP_DIV(iw - sx, im2colParameter->dilateX));
        int fyC = efy - sfy;
        int fxC = efx - sfx;

        auto colAddrI = colAddr + GEMM_INT8_UNIT * i;
        auto inputOffset = src + (sy + sfy * dilateY) * srcYStep + (sx + sfx * dilateX) * GEMM_INT8_UNIT;
        auto indexOffset = (sfy * kw + sfx) * icDiv4;

        for (int fy = 0; fy < fyC; ++fy) {
            for (int fx = 0; fx < fxC; ++fx) {
                auto inputK = inputOffset + fy * dilateY * srcYStep + fx * dilateX * GEMM_INT8_UNIT;
                auto indexStart = (indexOffset + (fy * kw + fx) * icDiv4) * dstXStepInt32;
                for (int sz = 0; sz < icDiv4; ++sz) {
                    auto dstK0 = (int32_t*)colAddrI + indexStart + sz * dstXStepInt32;
                    dstK0[0] = *((int32_t*)inputK);
                    inputK += srcChannelStride;
                }
            }
        }
    }
}

static void _fastIm2ColSdot(int8_t* colAddr, const int8_t* inputOrigin, int32_t inputZeroPoint,
                            const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
                            size_t realDstCount) {
    const int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
    ::memset(colAddr, inputZeroPoint, col_buffer_size);
    const int icDiv4 = im2colParameter->icDiv4;
    const int srcZStep = im2colParameter->iw * im2colParameter->ih * GEMM_INT8_UNIT;
    inputOrigin += xIndexStart * GEMM_INT8_UNIT;
    for (int i = 0; i < realDstCount; ++i) {
        auto colAddrI = colAddr + GEMM_INT8_UNIT * i;
        auto inputK = inputOrigin + GEMM_INT8_UNIT * i;
        for (int sz = 0; sz < icDiv4; ++sz) {
            auto inputZ0 = inputK + srcZStep * sz;
            auto dstK0 = colAddrI + sz * GEMM_INT8_UNIT * GEMM_INT8_DST_XUNIT;
            *((int32_t*)dstK0) = *((int32_t*)inputZ0);
        }
    }
}

static MNN::CoreInt8Functions::Im2ColFunc chooseIm2ColSdot(const MNN::ConvolutionCommon::Im2ColParameter* im2colParam, size_t inputChannel) {
    bool fastIm2Col = im2colParam->kernelX == 1 && im2colParam->kernelY == 1 && im2colParam->icDiv4 % 2 == 0 &&
                      im2colParam->strideX == 1 && im2colParam->strideY == 1 && im2colParam->padX == 0 &&
                      im2colParam->padY == 0;
    int ih = im2colParam->ih, iw = im2colParam->iw;
    fastIm2Col &= (im2colParam->srcYStep == iw * GEMM_INT8_UNIT && im2colParam->srcZStep == ih * iw * GEMM_INT8_UNIT);
    if (fastIm2Col) {
        return _fastIm2ColSdot;
    } else {
        return _im2colCommonSdot;
    }
}

static void MNNGetGemmUnitSdot(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) {
    *UNIT = GEMM_INT8_UNIT;
    *SRC_UNIT = GEMM_INT8_SRC_UNIT;
    *DST_XUNIT = GEMM_INT8_DST_XUNIT;
}

#undef GEMM_INT8_UNIT
#undef GEMM_INT8_SRC_UNIT
#undef GEMM_INT8_DST_XUNIT
/* End */

/* CPU with i8mm */
#define GEMM_INT8_UNIT 4
#define GEMM_INT8_SRC_UNIT 8
#define GEMM_INT8_DST_XUNIT 20

// Used when icDiv4 % 2 == 0: input channels are copied two 4-packs at a time.
static void _im2colCommonI8mm(int8_t* colAddr, const int8_t* src, int32_t inputZeroPoint,
                              const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
                              size_t realDstCount) {
    const int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
    ::memset(colAddr, inputZeroPoint, col_buffer_size); // pad with the input zero point; a single value suffices since per-channel zero points were removed
    auto ih = im2colParameter->ih;
    auto iw = im2colParameter->iw;
    auto kh = im2colParameter->kernelY;
    auto kw = im2colParameter->kernelX;
    auto dilateX = im2colParameter->dilateX;
    auto dilateY = im2colParameter->dilateY;
    auto icDiv4 = im2colParameter->icDiv4;
    auto srcZStep = im2colParameter->srcZStep;
    auto srcYStep = im2colParameter->srcYStep;
    constexpr int dstXStepInt32 = GEMM_INT8_SRC_UNIT * GEMM_INT8_DST_XUNIT / sizeof(int32_t);
    auto icDiv8 = icDiv4 / 2;
    for (int i = 0; i < realDstCount; ++i) {
        int xIndex = (int)xIndexStart + i;
        int ox = xIndex % im2colParameter->ow;
        int oy = xIndex / im2colParameter->ow;
        int sx = ox * im2colParameter->strideX - im2colParameter->padX;
        int sy = oy * im2colParameter->strideY - im2colParameter->padY;
        int sfy = ALIMAX(0, (UP_DIV(-sy, im2colParameter->dilateY)));
        int efy = ALIMIN(kh, UP_DIV(ih - sy, im2colParameter->dilateY));
        int sfx = ALIMAX(0, (UP_DIV(-sx, im2colParameter->dilateX)));
        int efx = ALIMIN(kw, UP_DIV(iw - sx, im2colParameter->dilateX));
        int fyC = efy - sfy;
        int fxC = efx - sfx;
        auto colAddrI = colAddr + GEMM_INT8_SRC_UNIT * i;
        auto inputOffset = src + (sy + sfy * dilateY) * srcYStep + (sx + sfx * dilateX) * GEMM_INT8_UNIT;
        auto indexOffset = (sfy * kw + sfx) * icDiv8;
        for (int fy = 0; fy < fyC; ++fy) {
            for (int fx = 0; fx < fxC; ++fx) {
                auto inputK = inputOffset + fy * dilateY * srcYStep + fx * dilateX * GEMM_INT8_UNIT;
                auto indexStart = indexOffset + (fy * kw + fx) * icDiv8;
                for (int sz = 0; sz < icDiv8; ++sz) {
                    const int yIndex = indexStart + sz;
                    auto dstK0 = (int32_t*)colAddrI + yIndex * dstXStepInt32;
                    dstK0[0] = *((int32_t*)inputK);
                    dstK0[1] = *((int32_t*)(inputK + srcZStep));
                    inputK += 2 * srcZStep;
                }
            }
        }
    }
}

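/* With GEMM_INT8_SRC_UNIT = 8 in the i8mm layout, _im2colCommonI8mm copies two
   consecutive 4-channel slices per step (dstK0[0] and dstK0[1]) and iterates
   icDiv4 / 2 times; _slowIm2ColI8mm below handles odd icDiv4 one 4-channel
   slice at a time. */
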
static void _slowIm2ColI8mm(int8_t* colAddr, const int8_t* src, int32_t inputZeroPoint,
                            const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
                            size_t realDstCount) {
    const int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
    ::memset(colAddr, inputZeroPoint, col_buffer_size); // pad with the input zero point; a single value suffices since per-channel zero points were removed
    auto ih = im2colParameter->ih;
    auto iw = im2colParameter->iw;
    auto kh = im2colParameter->kernelY;
    auto kw = im2colParameter->kernelX;
    auto dilateX = im2colParameter->dilateX;
    auto dilateY = im2colParameter->dilateY;
    auto icDiv4 = im2colParameter->icDiv4;
    auto srcZStep = im2colParameter->srcZStep;
    auto srcYStep = im2colParameter->srcYStep;
    constexpr int dstXStepInt32 = GEMM_INT8_SRC_UNIT * GEMM_INT8_DST_XUNIT / sizeof(int32_t);
    constexpr int SRC_DIV_UNIT = GEMM_INT8_SRC_UNIT / GEMM_INT8_UNIT;

    for (int i = 0; i < realDstCount; ++i) {
        int xIndex = (int)xIndexStart + i;
        int ox = xIndex % im2colParameter->ow;
        int oy = xIndex / im2colParameter->ow;
        int sx = ox * im2colParameter->strideX - im2colParameter->padX;
        int sy = oy * im2colParameter->strideY - im2colParameter->padY;
        int sfy = ALIMAX(0, (UP_DIV(-sy, im2colParameter->dilateY)));
        int efy = ALIMIN(kh, UP_DIV(ih - sy, im2colParameter->dilateY));
        int sfx = ALIMAX(0, (UP_DIV(-sx, im2colParameter->dilateX)));
        int efx = ALIMIN(kw, UP_DIV(iw - sx, im2colParameter->dilateX));
        int fyC = efy - sfy;
        int fxC = efx - sfx;
        auto colAddrI = colAddr + GEMM_INT8_SRC_UNIT * i;
        auto inputOffset = src + (sy + sfy * dilateY) * srcYStep + (sx + sfx * dilateX) * GEMM_INT8_UNIT;
        auto indexOffset = (sfy * kw + sfx) * icDiv4;
        for (int fy = 0; fy < fyC; ++fy) {
            for (int fx = 0; fx < fxC; ++fx) {
                auto inputK = inputOffset + fy * dilateY * srcYStep + fx * dilateX * GEMM_INT8_UNIT;
                auto indexStart = indexOffset + (fy * kw + fx) * icDiv4;
                for (int sz = 0; sz < icDiv4; ++sz) {
                    const int yIndex = indexStart + sz;
                    const int ySubOutside = yIndex / SRC_DIV_UNIT;
                    const int ySubInside = yIndex % SRC_DIV_UNIT;
                    auto dstK0 = (int32_t*)colAddrI + ySubOutside * dstXStepInt32 + ySubInside;
                    dstK0[0] = *((int32_t*)inputK);
                    inputK += srcZStep;
                }
            }
        }
    }
}

static void _fastIm2ColI8mm(int8_t* colAddr, const int8_t* inputOrigin, int32_t inputZeroPoint,
                            const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
                            size_t realDstCount) {
    const int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
    ::memset(colAddr, inputZeroPoint, col_buffer_size);
    const int icDiv8 = im2colParameter->icDiv4 / 2;
    const int srcZStep = im2colParameter->srcZStep;
    constexpr int dstXStepInt32 = GEMM_INT8_SRC_UNIT * GEMM_INT8_DST_XUNIT / sizeof(int32_t);
    inputOrigin += xIndexStart * GEMM_INT8_UNIT;
    for (int i = 0; i < realDstCount; ++i) {
        auto colAddrI = colAddr + GEMM_INT8_SRC_UNIT * i;
        auto inputK = inputOrigin + GEMM_INT8_UNIT * i;
        for (int sz = 0; sz < icDiv8; ++sz) {
            auto inputZ0 = inputK + srcZStep * sz * 2;
            auto dstK0 = (int32_t*)colAddrI + sz * dstXStepInt32;
            dstK0[0] = *((int32_t*)inputZ0);
            dstK0[1] = *((int32_t*)(inputZ0 + srcZStep));
        }
    }
}

static MNN::CoreInt8Functions::Im2ColFunc chooseIm2ColI8mm(const MNN::ConvolutionCommon::Im2ColParameter* im2colParam, size_t inputChannel) {
    bool fastIm2Col = im2colParam->kernelX == 1 && im2colParam->kernelY == 1 && im2colParam->icDiv4 % 2 == 0 &&
                      im2colParam->strideX == 1 && im2colParam->strideY == 1 && im2colParam->padX == 0 &&
                      im2colParam->padY == 0;
    int ih = im2colParam->ih, iw = im2colParam->iw;
    fastIm2Col &= (im2colParam->srcYStep == iw * GEMM_INT8_UNIT && im2colParam->srcZStep == ih * iw * GEMM_INT8_UNIT);
    if (fastIm2Col) {
        return _fastIm2ColI8mm;
    } else {
        if (im2colParam->icDiv4 % 2) {
            return _slowIm2ColI8mm;
        } else {
            return _im2colCommonI8mm;
        }
    }
}

static void MNNGetGemmUnitI8mm(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) {
    *UNIT = GEMM_INT8_UNIT;
    *SRC_UNIT = GEMM_INT8_SRC_UNIT;
    *DST_XUNIT = GEMM_INT8_DST_XUNIT;
}
#undef GEMM_INT8_UNIT
#undef GEMM_INT8_SRC_UNIT
#undef GEMM_INT8_DST_XUNIT
/* End */

namespace MNN {

static CoreInt8Functions* gCoreFunc = nullptr;

void MNNCoreInt8FunctionInit() {
    /* CoreInt8Functions without sdot */
    gCoreFunc = new CoreInt8Functions;

    // MatMul
    gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_16x4_Unit;
    gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_16x4_Unit_FAST;
    gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnit;

    // Im2Col
    gCoreFunc->chooseIm2Col = chooseIm2Col;
    // conv depthwise
    gCoreFunc->ConvDepthwiseLineInt8 = MNNLineDepthWiseInt8AddBiasScaleUnit;
    gCoreFunc->MNNFloat2Int8 = MNNFloat2Int8;
    gCoreFunc->MNNInt8ScaleToFloat = MNNInt8ScaleToFloat;

    // sparse
    gCoreFunc->MNNGetSparseQuantMatMulPackMode = MNNGetSparseQuantMatMulPackMode;
    gCoreFunc->MNNPackForSparseQuantMatMul_B = MNNPackForSparseQuantMatMul_B;
    gCoreFunc->MNNPackedSparseQuantMatMulEpx1 = MNNPackedSparseQuantMatMulEpx1;
    gCoreFunc->MNNPackedSparseQuantMatMulEpx4 = MNNPackedSparseQuantMatMulEpx4;
    gCoreFunc->MNNSparseQuantIm2col = MNNSparseQuantIm2col;

    // pooling
    gCoreFunc->MNNAvgPoolInt8 = MNNAvgPoolInt8;
    gCoreFunc->MNNMaxPoolInt8 = MNNMaxPoolInt8;

#if defined(__aarch64__)
    auto core = MNNGetCoreFunctions();
    if (core->supportSDot) {
        // MatMul
        gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_ARMV82_Unit;
        gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_ARMV82_Unit;
        gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnitSdot;
        // Im2Col
        gCoreFunc->chooseIm2Col = chooseIm2ColSdot;
    }
    if (core->supportI8mm) {
        // MatMul
        gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_ARMV86_Unit;
        gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_ARMV86_Unit;
        gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnitI8mm;
        // Im2Col
        gCoreFunc->chooseIm2Col = chooseIm2ColI8mm;
    }
#endif
    MNNInt8FunctionInit();
}

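/* Typical usage (sketch): MNNCoreInt8FunctionInit() runs once during backend
   setup, after the float CoreFunctions (MNNGetCoreFunctions) have detected
   supportSDot / supportI8mm; callers then query the dispatch table, e.g.

       auto core = MNN::MNNGetInt8CoreFunctions();
       int UNIT, SRC_UNIT, DST_XUNIT;
       core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);

   and use the returned tile sizes to lay out im2col buffers for Int8GemmKernel. */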
CoreInt8Functions* MNNGetInt8CoreFunctions() {
    return gCoreFunc;
}
} // namespace MNN