//
//  Int8FunctionsOpt.cpp
//  MNN
//
//  Created by MNN on 2018/08/15.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <math.h>
#include <cstring> // for memset
#include "Int8FunctionsOpt.h"
#include "core/Macro.h"
#include "CommonOptFunction.h"

#ifdef MNN_USE_NEON
#include <arm_neon.h>

extern "C" {
void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad,
                                       const QuanPostTreatParameters* post, size_t realCount);
void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad,
                                            const QuanPostTreatParameters* post, size_t realCount);
void MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width,
                                          size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step);
#if defined(__aarch64__) && defined(MNN_USE_ARMV82) // aarch32 sdot workaround
void MNNGemmInt8AddBiasScale_ARMV82_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad,
                                         const QuanPostTreatParameters* post, size_t realDstCount);
#endif // __aarch64__ && MNN_USE_ARMV82
}
#endif // MNN_USE_NEON

#ifndef MNN_USE_NEON
static int8_t MNNInt32ToInt8(int data, int bias, float scale, float maxValue, float minValue)
{
    float value = (float)(data + bias) * scale;
    value       = ALIMAX(value, minValue);
    value       = ALIMIN(value, maxValue);
    return static_cast<int8_t>(roundf(value));
}
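// Worked example of the requantization above (illustrative numbers, not taken from
// any real layer): data = 100, bias = 8, scale = 0.05f, range [-127, 127]:
//   (100 + 8) * 0.05f = 5.4f -> clamped to [-127, 127] -> roundf -> 5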

static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step,
                                              size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) {
    const auto dst_step_tmp = dst_step / sizeof(int8_t);
    for (int dz = 0; dz < dst_depth_quad; ++dz) {
        const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
        const auto bias_dz   = post->bias + dz * GEMM_INT8_UNIT;
        const auto scale_dz  = post->scale + dz * GEMM_INT8_UNIT;
        auto dst_z           = dst + dz * dst_step_tmp;
        for (int w = 0; w < realCount; ++w) {
            const auto src_x   = src + w * GEMM_INT8_SRC_UNIT;
            auto dst_x         = dst_z + w * GEMM_INT8_UNIT;
            int32_t dstTemp[4] = {0, 0, 0, 0};

            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sz;
                const auto src_z     = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT;

                for (int j = 0; j < GEMM_INT8_UNIT; ++j) {
                    const auto weight_j = weight_sz + j * GEMM_INT8_SRC_UNIT;
                    for (int i = 0; i < GEMM_INT8_SRC_UNIT; ++i) {
                        dstTemp[j] += (int32_t)src_z[i] * (int32_t)weight_j[i];
                    }
                }
            }

            for (int j = 0; j < 4; ++j) {
                if (post->scale != nullptr) { // requantize to int8 when a scale is provided
                    dst_x[j] = MNNInt32ToInt8(dstTemp[j], bias_dz[j], scale_dz[j], post->maxValue, post->minValue);
                } else { // otherwise emit the raw biased accumulators as float
                    ((float*)dst_x)[j] = (float)(dstTemp[j] + bias_dz[j]);
                }
            }
        }
    }
}
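// Reference layout notes (matching the loops above): weights are packed as
// [dst_depth_quad][src_depth_quad][GEMM_INT8_UNIT][GEMM_INT8_SRC_UNIT], the im2col
// source as [src_depth_quad][GEMM_INT8_DST_XUNIT][GEMM_INT8_SRC_UNIT], and each of
// the realCount output columns produces GEMM_INT8_UNIT values written at
// dst + dz * dst_step + w * GEMM_INT8_UNIT (int8 elements).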

static void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) {
    return MNNGemmInt8AddBiasScale_16x4_Unit(dst, src, weight, src_depth_quad, dst_step, dst_depth_quad, post, realCount);
}

static void MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters,
                                          size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step,
                                          size_t dilateY_step) {
    auto bias_z = parameters->bias;
    auto scale_z = parameters->scale;
    int dx, fx, fy;
    for (dx = 0; dx < width; ++dx) {
        auto dst_x          = dst + dx * 4;
        int32_t dstInt32[4] = {0, 0, 0, 0};
        const auto src_z    = src + src_w_step * dx;
        for (fy = 0; fy < fh; ++fy) {
            const auto src_y    = src_z + fy * dilateY_step;
            const auto weight_y = weight + fy * fw * 4;
            for (fx = 0; fx < fw; ++fx) {
                const auto src_x    = src_y + fx * dilateX_step;
                const auto weight_x = weight_y + 4 * fx;
                for (int j = 0; j < GEMM_INT8_UNIT; ++j) {
                    dstInt32[j] += (int32_t)src_x[j] * (int32_t)weight_x[j];
                }
            }
        }

        for (int i = 0; i < GEMM_INT8_UNIT; ++i) {
            dst_x[i] = MNNInt32ToInt8(dstInt32[i], bias_z[i], scale_z[i], parameters->maxValue, parameters->minValue);
        }
    }
}
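// Depthwise note: each output position dx accumulates fw * fh filter taps for the four
// channels of one NC4HW4 quad; src_w_step and the dilate*_step values are element
// strides into the int8 source, so stride and dilation are encoded in the caller's
// step parameters rather than recomputed inside this loop.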
#endif

#ifndef MNN_USE_SSE
void MNNInt8FunctionInit() {
    // do nothing
}
#ifndef MNN_USE_NEON
void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue,
                   ssize_t maxValue, ssize_t zeroPoint) {
    for (int i = 0; i < sizeQuad; ++i) {
        for (int j = 0; j < 4; ++j) {
            int v = (int)roundf(src[4 * i + j] * scalep[j]) + zeroPoint;
            if (v > maxValue) {
                v = maxValue;
            }
            if (v < minValue) {
                v = minValue;
            }
            dst[4 * i + j] = v;
        }
    }
}

void MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t size, ssize_t zeroPoint) {
    for (int i = 0; i < size; ++i) {
        const auto srcStart = src + i * 4;
        auto dstStart       = dst + i * 4;
        for (int j = 0; j < 4; ++j) {
            dstStart[j] = static_cast<float>(srcStart[j] - zeroPoint) * scale[j];
        }
    }
}
#endif // #ifndef MNN_USE_NEON
#endif // #ifndef MNN_USE_SSE

/* CPU without sdot */
// Assume GEMM_INT8_UNIT == 4 && GEMM_INT8_SRC_UNIT == 16
static void _fastIm2Col(int8_t* colAddr, const int8_t* inputOrigin, int8_t inputZeroPoint,
                        const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
                        size_t realDstCount) {
    const int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_SRC_UNIT * GEMM_INT8_DST_XUNIT * sizeof(int8_t);
    ::memset(colAddr, inputZeroPoint, col_buffer_size); // pad with the input zero point; safe because per-channel zero points are not used

    const int icDiv8   = im2colParameter->icDiv4 / 2;
    const int srcZStep = im2colParameter->iw * im2colParameter->ih * GEMM_INT8_UNIT;
    inputOrigin += xIndexStart * GEMM_INT8_UNIT;
    for (int i = 0; i < realDstCount; ++i) {
        auto colAddrI = colAddr + GEMM_INT8_SRC_UNIT * i;
        auto inputK   = inputOrigin + GEMM_INT8_UNIT * i;
        for (int sz = 0; sz < icDiv8; ++sz) {
            auto inputZ0           = inputK + srcZStep * (2 * sz + 0);
            auto inputZ1           = inputK + srcZStep * (2 * sz + 1);
            const int indexOutside = sz / 2;
            const int indexInsize  = sz % 2;

            auto dstK0         = colAddrI + (indexOutside * GEMM_INT8_DST_XUNIT * 2 + indexInsize) * (2 * GEMM_INT8_UNIT);
            auto dstK1         = dstK0 + GEMM_INT8_UNIT;
            *((int32_t*)dstK0) = *((int32_t*)inputZ0);
            *((int32_t*)dstK1) = *((int32_t*)inputZ1);
        }
    }
}

static void _im2colCommonZ1(int8_t* colAddr, const int8_t* inputOrigin, int8_t inputZeroPoint,
                            const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
                            size_t realDstCount) {
    int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
    ::memset(colAddr, inputZeroPoint, col_buffer_size); // pad with the input zero point; safe because per-channel zero points are not used

    auto ih                     = im2colParameter->ih;
    auto iw                     = im2colParameter->iw;
    auto kh                     = im2colParameter->kernelY;
    auto kw                     = im2colParameter->kernelX;
    auto dilateX                = im2colParameter->dilateX;
    auto dilateY                = im2colParameter->dilateY;
    auto srcYStep               = im2colParameter->srcYStep;
    constexpr int dstXStepInt32 = GEMM_INT8_SRC_UNIT * GEMM_INT8_DST_XUNIT / sizeof(int32_t);
    for (int i = 0; i < realDstCount; ++i) {
        int xIndex = (int)xIndexStart + i;
        int ox     = xIndex % im2colParameter->ow;
        int oy     = xIndex / im2colParameter->ow;

        int sx = ox * im2colParameter->strideX - im2colParameter->padX;
        int sy = oy * im2colParameter->strideY - im2colParameter->padY;

        int sfy = ALIMAX(0, (UP_DIV(-sy, im2colParameter->dilateY)));
        int efy = ALIMIN(kh, UP_DIV(ih - sy, im2colParameter->dilateY));
        int sfx = ALIMAX(0, (UP_DIV(-sx, im2colParameter->dilateX)));
        int efx = ALIMIN(kw, UP_DIV(iw - sx, im2colParameter->dilateX));
        int fyC = efy - sfy;
        int fxC = efx - sfx;

        auto colAddrI    = colAddr + GEMM_INT8_SRC_UNIT * i;

        auto inputOffset = inputOrigin + (sy + sfy * dilateY) * srcYStep + (sx + sfx * dilateX) * GEMM_INT8_UNIT;
        auto indexOffset = sfy * kw + sfx;
        for (int fy = 0; fy < fyC; ++fy) {
            for (int fx = 0; fx < fxC; ++fx) {
                auto inputK       = inputOffset + fy * dilateY * srcYStep + fx * dilateX * GEMM_INT8_UNIT;
                auto indexStart   = indexOffset + fy * kw + fx;
                auto indexInside  = indexStart % 4;
                auto indexOutside = indexStart / 4;
                auto dstK0        = (int32_t*)colAddrI + indexOutside * dstXStepInt32 + indexInside;
                dstK0[0]          = *((int32_t*)inputK);
            }
        }
    }
}

static void _im2colCommon(int8_t* colAddr, const int8_t* inputOrigin, int8_t inputZeroPoint,
                          const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
                          size_t realDstCount) {
    const int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
    ::memset(colAddr, inputZeroPoint, col_buffer_size); // pad with the input zero point; safe because per-channel zero points are not used

    auto ih                     = im2colParameter->ih;
    auto iw                     = im2colParameter->iw;
    auto kh                     = im2colParameter->kernelY;
    auto kw                     = im2colParameter->kernelX;
    auto dilateX                = im2colParameter->dilateX;
    auto dilateY                = im2colParameter->dilateY;
    auto icDiv4                 = im2colParameter->icDiv4;
    auto srcZStep               = im2colParameter->srcZStep;
    auto srcYStep               = im2colParameter->srcYStep;
    constexpr int dstXStepInt32 = GEMM_INT8_SRC_UNIT * GEMM_INT8_DST_XUNIT / sizeof(int32_t);
    for (int i = 0; i < realDstCount; ++i) {
        int xIndex = (int)xIndexStart + i;
        int ox     = xIndex % im2colParameter->ow;
        int oy     = xIndex / im2colParameter->ow;

        int sx = ox * im2colParameter->strideX - im2colParameter->padX;
        int sy = oy * im2colParameter->strideY - im2colParameter->padY;

        int sfy = ALIMAX(0, (UP_DIV(-sy, im2colParameter->dilateY)));
        int efy = ALIMIN(kh, UP_DIV(ih - sy, im2colParameter->dilateY));
        int sfx = ALIMAX(0, (UP_DIV(-sx, im2colParameter->dilateX)));
        int efx = ALIMIN(kw, UP_DIV(iw - sx, im2colParameter->dilateX));
        int fyC = efy - sfy;
        int fxC = efx - sfx;

        auto colAddrI    = colAddr + GEMM_INT8_SRC_UNIT * i;

        auto inputOffset = inputOrigin + (sy + sfy * dilateY) * srcYStep + (sx + sfx * dilateX) * GEMM_INT8_UNIT;
        auto indexOffset = (sfy * kw + sfx) * icDiv4;
        for (int fy = 0; fy < fyC; ++fy) {
            for (int fx = 0; fx < fxC; ++fx) {
                auto inputK     = inputOffset + fy * dilateY * srcYStep + fx * dilateX * GEMM_INT8_UNIT;
                auto indexStart = indexOffset + (fy * kw + fx) * icDiv4;
                for (int sz = 0; sz < icDiv4; ++sz) {
                    const int yIndex      = indexStart + sz;
                    const int ySubOutside = yIndex / GEMM_INT8_UNIT;
                    const int ySubInside  = yIndex % GEMM_INT8_UNIT;
                    auto dstK0            = (int32_t*)colAddrI + ySubOutside * dstXStepInt32 + ySubInside;
                    dstK0[0]              = *((int32_t*)inputK);
                    inputK += srcZStep;
                }
            }
        }
    }
}

static MNN::CoreInt8Functions::Im2ColFunc chooseIm2Col(const MNN::ConvolutionCommon::Im2ColParameter* im2colParam, size_t inputChannel) {
    bool fastIm2Col = im2colParam->kernelX == 1 && im2colParam->kernelY == 1 && im2colParam->icDiv4 % 2 == 0 &&
                      im2colParam->strideX == 1 && im2colParam->strideY == 1 && im2colParam->padX == 0 &&
                      im2colParam->padY == 0;
    int ih = im2colParam->ih, iw = im2colParam->iw;
    fastIm2Col &= (im2colParam->srcYStep == iw * GEMM_INT8_UNIT && im2colParam->srcZStep == ih * iw * GEMM_INT8_UNIT);
    if (fastIm2Col) {
        return _fastIm2Col;
    } else if (inputChannel <= 4) {
        return _im2colCommonZ1;
    } else {
        return _im2colCommon;
    }
}
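// Dispatch note: the fast path only applies to 1x1 kernels with stride 1, no padding,
// an even icDiv4 and a densely packed NC4HW4 source (srcYStep / srcZStep match iw / ih),
// so im2col degenerates to reordering quads; _im2colCommonZ1 handles <= 4 input channels
// (a single quad per position) and _im2colCommon covers the general case.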

static void MNNGetGemmUnit(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) {
    *UNIT = GEMM_INT8_UNIT;
    *SRC_UNIT = GEMM_INT8_SRC_UNIT;
    *DST_XUNIT = GEMM_INT8_DST_XUNIT;
}
#undef GEMM_INT8_UNIT
#undef GEMM_INT8_SRC_UNIT
#undef GEMM_INT8_DST_XUNIT
/* End */

/* CPU with sdot */
#define GEMM_INT8_UNIT 4
#define GEMM_INT8_SRC_UNIT 4

#ifdef __aarch64__
#define GEMM_INT8_DST_XUNIT 12
#else
#define GEMM_INT8_DST_XUNIT 8
#endif

static void _im2colCommonSdot(int8_t* colAddr, const int8_t* src, int8_t inputZeroPoint,
                                const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
                                size_t realDstCount) {
    const int colBufferSize = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
    memset(colAddr, inputZeroPoint, colBufferSize);
    auto ih = im2colParameter->ih;
    auto iw = im2colParameter->iw;
    // auto oh = im2colParameter->oh;
    auto ow                     = im2colParameter->ow;
    auto kh                     = im2colParameter->kernelY;
    auto kw                     = im2colParameter->kernelX;
    auto dilateX                = im2colParameter->dilateX;
    auto dilateY                = im2colParameter->dilateY;
    auto icDiv4                 = im2colParameter->icDiv4;
    auto srcChannleStride       = im2colParameter->srcZStep;
    auto srcYStep               = im2colParameter->srcYStep;
    constexpr int dstXStepInt32 = GEMM_INT8_UNIT * GEMM_INT8_DST_XUNIT / sizeof(int32_t);

    for (int i = 0; i < realDstCount; ++i) {
        int xIndex = (int)xIndexStart + i;
        int ox     = xIndex % ow;
        int oy     = xIndex / ow;
        int sx     = ox * im2colParameter->strideX - im2colParameter->padX;
        int sy     = oy * im2colParameter->strideY - im2colParameter->padY;
        int sfy    = ALIMAX(0, (UP_DIV(-sy, im2colParameter->dilateY)));
        int efy    = ALIMIN(kh, UP_DIV(ih - sy, im2colParameter->dilateY));
        int sfx    = ALIMAX(0, (UP_DIV(-sx, im2colParameter->dilateX)));
        int efx    = ALIMIN(kw, UP_DIV(iw - sx, im2colParameter->dilateX));
        int fyC    = efy - sfy;
        int fxC    = efx - sfx;

        auto colAddrI    = colAddr + GEMM_INT8_UNIT * i;
        auto inputOffset = src + (sy + sfy * dilateY) * srcYStep + (sx + sfx * dilateX) * GEMM_INT8_UNIT;
        auto indexOffset = (sfy * kw + sfx) * icDiv4;

        for (int fy = 0; fy < fyC; ++fy) {
            for (int fx = 0; fx < fxC; ++fx) {
                auto inputK     = inputOffset + fy * dilateY * srcYStep + fx * dilateX * GEMM_INT8_UNIT;
                auto indexStart = (indexOffset + (fy * kw + fx) * icDiv4) * dstXStepInt32;
                for (int sz = 0; sz < icDiv4; ++sz) {
                    auto dstK0 = (int32_t*)colAddrI + indexStart + sz * dstXStepInt32;
                    dstK0[0]   = *((int32_t*)inputK);
                    inputK += srcChannleStride;
                }
            }
        }
    }
}

static void _fastIm2ColSdot(int8_t* colAddr, const int8_t* inputOrigin, int8_t inputZeroPoint,
                              const MNN::ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
                              size_t realDstCount) {
    const int col_buffer_size = im2colParameter->kernelCountUnit * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT * sizeof(int8_t);
    ::memset(colAddr, inputZeroPoint, col_buffer_size);
    const int icDiv4   = im2colParameter->icDiv4;
    const int srcZStep = im2colParameter->iw * im2colParameter->ih * GEMM_INT8_UNIT;
    inputOrigin += xIndexStart * GEMM_INT8_UNIT;
    for (int i = 0; i < realDstCount; ++i) {
        auto colAddrI = colAddr + GEMM_INT8_UNIT * i;
        auto inputK   = inputOrigin + GEMM_INT8_UNIT * i;
        for (int sz = 0; sz < icDiv4; ++sz) {
            auto inputZ0       = inputK + srcZStep * sz;
            auto dstK0         = colAddrI + sz * GEMM_INT8_UNIT * GEMM_INT8_DST_XUNIT;
            *((int32_t*)dstK0) = *((int32_t*)inputZ0);
        }
    }
}
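// With sdot, GEMM_INT8_SRC_UNIT equals GEMM_INT8_UNIT (4), so the column buffer keeps
// the NC4HW4 quad granularity: the fast path above is a single 4-byte copy per input
// quad, and _im2colCommonSdot only gathers the quads touched by the kernel window.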

static MNN::CoreInt8Functions::Im2ColFunc chooseIm2ColSdot(const MNN::ConvolutionCommon::Im2ColParameter* im2colParam, size_t inputChannel) {
    bool fastIm2Col = im2colParam->kernelX == 1 && im2colParam->kernelY == 1 && im2colParam->icDiv4 % 2 == 0 &&
                      im2colParam->strideX == 1 && im2colParam->strideY == 1 && im2colParam->padX == 0 &&
                      im2colParam->padY == 0;
    int ih = im2colParam->ih, iw = im2colParam->iw;
    fastIm2Col &= (im2colParam->srcYStep == iw * GEMM_INT8_UNIT && im2colParam->srcZStep == ih * iw * GEMM_INT8_UNIT);
    if (fastIm2Col) {
        return _fastIm2ColSdot;
    } else {
        return _im2colCommonSdot;
    }
}

static void MNNGetGemmUnitSdot(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) {
    *UNIT = GEMM_INT8_UNIT;
    *SRC_UNIT = GEMM_INT8_SRC_UNIT;
    *DST_XUNIT = GEMM_INT8_DST_XUNIT;
}

/* End */
#undef GEMM_INT8_UNIT
#undef GEMM_INT8_SRC_UNIT
#undef GEMM_INT8_DST_XUNIT

namespace MNN {

static CoreInt8Functions* gCoreFunc = nullptr;

void MNNCoreInt8FunctionInit() {
    /* CoreInt8Functions without sdot */
    gCoreFunc = new CoreInt8Functions;
    // MatMul
    gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_16x4_Unit;
    gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_16x4_Unit_FAST;
    gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnit;
    // Im2Col
    gCoreFunc->chooseIm2Col = chooseIm2Col;
    // conv depthwise
    gCoreFunc->ConvDepthwiseLineInt8 = MNNLineDepthWiseInt8AddBiasScaleUnit;

#if defined(__aarch64__) && defined(MNN_USE_ARMV82)
    auto core = MNNGetCoreFunctions();
    if (core->supportSDot) {
        // MatMul
        gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_ARMV82_Unit;
        gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_ARMV82_Unit;
        gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnitSdot;
        // Im2Col
        gCoreFunc->chooseIm2Col = chooseIm2ColSdot;
    }
#endif
    MNNInt8FunctionInit();
}
CoreInt8Functions* MNNGetInt8CoreFunctions() {
    return gCoreFunc;
}
} // namespace MNN
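// A minimal usage sketch, assuming only the declarations in Int8FunctionsOpt.h (the
// exact call sequence used by the CPU backend may differ):
//   MNN::MNNCoreInt8FunctionInit();
//   auto int8Core = MNN::MNNGetInt8CoreFunctions();
//   int unit = 0, srcUnit = 0, dstXUnit = 0;
//   int8Core->MNNGetGemmUnit(&unit, &srcUnit, &dstXUnit); // tile sizes used for packing
//   // pack tiles with the function returned by int8Core->chooseIm2Col(...), then run
//   // int8Core->Int8GemmKernel(...) on each packed tile.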