// Mirror of https://github.com/alibaba/MNN.git
// (C++, 421 lines, 17 KiB)
//
//  WinogradInt8Helper.cpp
//  MNN
//
//  Created by MNN on 2018/07/16.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#if __GNUC__ == 4
 | 
						|
#pragma GCC optimize("-flax-vector-conversions")
 | 
						|
#endif
 | 
						|
 | 
						|
#include <limits>
 | 
						|
#include <vector>
 | 
						|
#include <map>
 | 
						|
#include <tuple>
 | 
						|
#include <functional>
 | 
						|
#include "WinogradInt8Helper.hpp"
 | 
						|
#include "Int8FunctionsOpt.h"
 | 
						|
#include "core/TensorUtils.hpp"
 | 
						|
#include "math/Vec.hpp"
 | 
						|
 | 
						|
namespace MNN {
 | 
						|
 | 
						|
#if (defined(MNN_USE_NEON) && defined(__aarch64__)) || defined(MNN_USE_SSE)
using VecType = MNN::Math::Vec<int8_t, 16>;

// Transposes a 4x4 matrix of 32-bit lanes held in four 16-byte vectors:
// afterwards lane j of vecI holds what was lane I of vecJ.
// For C4 int8 data (4 bytes per lane) this turns "4 channel blocks x 4 positions"
// into "4 positions x 4 channel blocks".
static inline void TRANS_4x4(VecType& vec0, VecType& vec1, VecType& vec2, VecType& vec3) {
#if defined(MNN_USE_SSE)
    // Reuse the float 4x4 transpose macro on bit-cast integer data.
    auto m0 = _mm_castsi128_ps(vec0.value);
    auto m1 = _mm_castsi128_ps(vec1.value);
    auto m2 = _mm_castsi128_ps(vec2.value);
    auto m3 = _mm_castsi128_ps(vec3.value);
    _MM_TRANSPOSE4_PS(m0, m1, m2, m3);
    vec0.value = _mm_castps_si128(m0);
    vec1.value = _mm_castps_si128(m1);
    vec2.value = _mm_castps_si128(m2);
    vec3.value = _mm_castps_si128(m3);
#else
    // Explicit reinterprets instead of relying on lax vector conversions (int8x16_t passed where
    // int32x4_t / int64x2_t is expected), which GCC rejects unless -flax-vector-conversions is on.
    int32x4_t r0 = vreinterpretq_s32_s8(vec0.value);
    int32x4_t r1 = vreinterpretq_s32_s8(vec1.value);
    int32x4_t r2 = vreinterpretq_s32_s8(vec2.value);
    int32x4_t r3 = vreinterpretq_s32_s8(vec3.value);
    // Interleave 32-bit lanes pairwise, then 64-bit halves, to complete the transpose.
    int64x2_t t0 = vreinterpretq_s64_s32(vtrn1q_s32(r0, r1));
    int64x2_t t1 = vreinterpretq_s64_s32(vtrn2q_s32(r0, r1));
    int64x2_t t2 = vreinterpretq_s64_s32(vtrn1q_s32(r2, r3));
    int64x2_t t3 = vreinterpretq_s64_s32(vtrn2q_s32(r2, r3));
    vec0.value = vreinterpretq_s8_s64(vtrn1q_s64(t0, t2));
    vec1.value = vreinterpretq_s8_s64(vtrn1q_s64(t1, t3));
    vec2.value = vreinterpretq_s8_s64(vtrn2q_s64(t0, t2));
    vec3.value = vreinterpretq_s8_s64(vtrn2q_s64(t1, t3));
#endif
}
#endif
 | 
						|
 | 
						|
// winograd source transform with simd, C4 -> C4
 | 
						|
static void _sourceTransUnit4x4Pack4x4(const int8_t* srcStart, int8_t* dstStart, size_t srcZStep, size_t dstXStep, size_t dstZStep, size_t countC4) {
 | 
						|
#if defined(MNN_USE_NEON) && defined(__aarch64__)
 | 
						|
    using VecType = MNN::Math::Vec<int8_t, 16>;
 | 
						|
    int countUnit = countC4 / 4, countRemain = countC4 % 4;
 | 
						|
    for (int z = 0; z < countUnit; ++z) {
 | 
						|
        // load, then 4x int8x4 => 1x int8x16, then do simd compute, save
 | 
						|
        VecType in[4] = {
 | 
						|
            VecType::load(srcStart + 0 * srcZStep),
 | 
						|
            VecType::load(srcStart + 1 * srcZStep),
 | 
						|
            VecType::load(srcStart + 2 * srcZStep),
 | 
						|
            VecType::load(srcStart + 3 * srcZStep)
 | 
						|
        };
 | 
						|
        TRANS_4x4(in[0], in[1], in[2], in[3]);
 | 
						|
        VecType m[4] = {
 | 
						|
            in[0] - in[2],
 | 
						|
            in[1] + in[2],
 | 
						|
            in[2] - in[1],
 | 
						|
            in[3] - in[1]
 | 
						|
        };
 | 
						|
        for (int i = 0; i < 4; ++i) {
 | 
						|
            auto tmp = vreinterpretq_s32_s8(m[i].value);
 | 
						|
            vst1q_lane_s32((int32_t*)(dstStart + 0 * dstZStep), tmp, 0);
 | 
						|
            vst1q_lane_s32((int32_t*)(dstStart + 1 * dstZStep), tmp, 1);
 | 
						|
            vst1q_lane_s32((int32_t*)(dstStart + 2 * dstZStep), tmp, 2);
 | 
						|
            vst1q_lane_s32((int32_t*)(dstStart + 3 * dstZStep), tmp, 3);
 | 
						|
            dstStart += dstXStep;
 | 
						|
        }
 | 
						|
        dstStart -= dstXStep * 4;
 | 
						|
        srcStart += srcZStep * 4;
 | 
						|
    }
 | 
						|
#else
 | 
						|
    int countUnit = 0, countRemain = countC4;
 | 
						|
#endif
 | 
						|
    // simd accelerate can't be used
 | 
						|
    using VecType1 = MNN::Math::Vec<int8_t, 4>;
 | 
						|
    for (int i = 0; i < countRemain; ++i) {
 | 
						|
        auto s0 = VecType1::load(srcStart + 0 * 4);
 | 
						|
        auto s1 = VecType1::load(srcStart + 1 * 4);
 | 
						|
        auto s2 = VecType1::load(srcStart + 2 * 4);
 | 
						|
        auto s3 = VecType1::load(srcStart + 3 * 4);
 | 
						|
        VecType1::save(dstStart + 0 * dstXStep, s0 - s2);
 | 
						|
        VecType1::save(dstStart + 1 * dstXStep, s1 + s2);
 | 
						|
        VecType1::save(dstStart + 2 * dstXStep, s2 - s1);
 | 
						|
        VecType1::save(dstStart + 3 * dstXStep, s3 - s1);
 | 
						|
        
 | 
						|
        srcStart += srcZStep;
 | 
						|
        dstStart += dstZStep;
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
// winograd source transform with simd, fused C4 -> C16 pack
 | 
						|
static void _sourceTransUnit4x4Pack4x16(const int8_t* srcStart, int8_t* dstStart, size_t srcZStep, size_t dstXStep, size_t dstZStep, size_t countC4) {
 | 
						|
    using VecType = MNN::Math::Vec<int8_t, 16>;
 | 
						|
#if (defined(MNN_USE_NEON) && defined(__aarch64__)) || defined(MNN_USE_SSE)
 | 
						|
    int countUnit = countC4 / 4, countRemain = countC4 % 4;
 | 
						|
    for (int z = 0; z < countUnit; ++z) {
 | 
						|
        // load, then 4x int8x4 => 1x int8x16, then do simd compute, save
 | 
						|
        auto in0 = VecType::load(srcStart + 0 * srcZStep);
 | 
						|
        auto in1 = VecType::load(srcStart + 1 * srcZStep);
 | 
						|
        auto in2 = VecType::load(srcStart + 2 * srcZStep);
 | 
						|
        auto in3 = VecType::load(srcStart + 3 * srcZStep);
 | 
						|
        TRANS_4x4(in0, in1, in2, in3);
 | 
						|
        VecType::save(dstStart + 0 * dstXStep, in0 - in2);
 | 
						|
        VecType::save(dstStart + 1 * dstXStep, in1 + in2);
 | 
						|
        VecType::save(dstStart + 2 * dstXStep, in2 - in1);
 | 
						|
        VecType::save(dstStart + 3 * dstXStep, in3 - in1);
 | 
						|
        srcStart += srcZStep * 4;
 | 
						|
        dstStart += dstZStep;
 | 
						|
    }
 | 
						|
#else
 | 
						|
    int countUnit = 0, countRemain = countC4;
 | 
						|
#endif
 | 
						|
    // simd accelerate can't be used
 | 
						|
    for (int i = 0; i < countRemain * 4; ++i) {
 | 
						|
        auto srcZ = srcStart + (i / 4) * srcZStep + (i % 4);
 | 
						|
        auto dstZ = dstStart + (i / 16) * dstZStep + (i % 16);
 | 
						|
        int8_t src[4];
 | 
						|
        for (int j = 0; j < 4; ++j) {
 | 
						|
            src[j] = srcZ[j * 4];
 | 
						|
        }
 | 
						|
        dstZ[0 * dstXStep] = src[0] - src[2];
 | 
						|
        dstZ[1 * dstXStep] = src[1] + src[2];
 | 
						|
        dstZ[2 * dstXStep] = src[2] - src[1];
 | 
						|
        dstZ[3 * dstXStep] = src[3] - src[1];
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
// winograd source transform with simd, fused C16 -> C4 pack
 | 
						|
static void _sourceTransUnit4x4Pack16x4(const int8_t* srcStart, int8_t* dstStart, size_t srcZStep, size_t dstXStep, size_t dstZStep, size_t countC4) {
 | 
						|
    using VecType = MNN::Math::Vec<int8_t, 16>;
 | 
						|
#if defined(MNN_USE_NEON) && defined(__aarch64__)
 | 
						|
    int countUnit = countC4 / 4, countRemain = countC4 % 4;
 | 
						|
    for (int z = 0; z < countUnit; ++z) {
 | 
						|
        // load 1x int8x16, then do simd compute, then 1x int8x16 => 4x int8x4, save
 | 
						|
        VecType in[4] = {
 | 
						|
            VecType::load(srcStart + 0 * srcZStep),
 | 
						|
            VecType::load(srcStart + 1 * srcZStep),
 | 
						|
            VecType::load(srcStart + 2 * srcZStep),
 | 
						|
            VecType::load(srcStart + 3 * srcZStep)
 | 
						|
        };
 | 
						|
        VecType m[4] = {
 | 
						|
            in[0] - in[2],
 | 
						|
            in[1] + in[2],
 | 
						|
            in[2] - in[1],
 | 
						|
            in[3] - in[1]
 | 
						|
        };
 | 
						|
        for (int i = 0; i < 4; ++i) {
 | 
						|
            auto tmp = vreinterpretq_s32_s8(m[i].value);
 | 
						|
            vst1q_lane_s32((int32_t*)(dstStart + 0 * dstZStep), tmp, 0);
 | 
						|
            vst1q_lane_s32((int32_t*)(dstStart + 1 * dstZStep), tmp, 1);
 | 
						|
            vst1q_lane_s32((int32_t*)(dstStart + 2 * dstZStep), tmp, 2);
 | 
						|
            vst1q_lane_s32((int32_t*)(dstStart + 3 * dstZStep), tmp, 3);
 | 
						|
            dstStart += dstXStep;
 | 
						|
        }
 | 
						|
        dstStart -= dstXStep * 4;
 | 
						|
        srcStart += srcZStep * 4;
 | 
						|
    }
 | 
						|
#else
 | 
						|
    int countUnit = 0, countRemain = countC4;
 | 
						|
#endif
 | 
						|
    // simd accelerate can't be used
 | 
						|
    for (int i = 0; i < countRemain * 4; ++i) {
 | 
						|
        auto srcZ = srcStart + (i / 4) * srcZStep + (i % 4);
 | 
						|
        auto dstZ = dstStart + (i / 16) * dstZStep + (i % 16);
 | 
						|
        int8_t src[4];
 | 
						|
        for (int j = 0; j < 4; ++j) {
 | 
						|
            src[j] = srcZ[j * 4];
 | 
						|
        }
 | 
						|
        dstZ[0 * dstXStep] = src[0] - src[2];
 | 
						|
        dstZ[1 * dstXStep] = src[1] + src[2];
 | 
						|
        dstZ[2 * dstXStep] = src[2] - src[1];
 | 
						|
        dstZ[3 * dstXStep] = src[3] - src[1];
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
// winograd source transform with simd, C16 -> C16, countC4 = UP_DIV(count, 16)
 | 
						|
static void _sourceTransUnit4x4Pack16x16(const int8_t* srcStart, int8_t* dstStart, size_t srcZStep, size_t dstXStep, size_t dstZStep, size_t countC4) {
 | 
						|
    using VecType = MNN::Math::Vec<int8_t, 16>;
 | 
						|
    for (int i = 0; i < countC4; ++i) {
 | 
						|
        auto s0 = VecType::load(srcStart + 0 * 16);
 | 
						|
        auto s1 = VecType::load(srcStart + 1 * 16);
 | 
						|
        auto s2 = VecType::load(srcStart + 2 * 16);
 | 
						|
        auto s3 = VecType::load(srcStart + 3 * 16);
 | 
						|
        VecType::save(dstStart + 0 * dstXStep, s0 - s2);
 | 
						|
        VecType::save(dstStart + 1 * dstXStep, s1 + s2);
 | 
						|
        VecType::save(dstStart + 2 * dstXStep, s2 - s1);
 | 
						|
        VecType::save(dstStart + 3 * dstXStep, s3 - s1);
 | 
						|
        
 | 
						|
        srcStart += srcZStep;
 | 
						|
        dstStart += dstZStep;
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
// Picks the source-transform kernel for tile size `alpha` and the given input/output
// channel packing (4 or 16). Returns nullptr for unsupported combinations.
WinogradInt8Helper::SrcTransFunc WinogradInt8Helper::chooseSourceTransform(int alpha, int inPack, int outPack) {
    // Only alpha == 4 source transforms are implemented.
    if (alpha != 4) {
        return nullptr;
    }
    if (inPack == 4 && outPack == 16) {
        return _sourceTransUnit4x4Pack4x16;
    }
    if (inPack == 16 && outPack == 4) {
        return _sourceTransUnit4x4Pack16x4;
    }
    if (inPack == 4 && outPack == 4) {
        return _sourceTransUnit4x4Pack4x4;
    }
    if (inPack == 16 && outPack == 16) {
        return _sourceTransUnit4x4Pack16x16;
    }
    return nullptr;
}
 | 
						|
 | 
						|
static void _destTransformUnit4x2(const float* srcStart, float* dstStart, size_t srcXStep, size_t srcZStep, size_t dstXStep, size_t dstZStep, size_t countC4) {
 | 
						|
    using VecType = MNN::Math::Vec<float, 4>;
 | 
						|
    VecType c0(0.5f);
 | 
						|
    for (int i = 0; i < countC4; ++i) {
 | 
						|
        auto x0 = VecType::load(srcStart + srcXStep * 0);
 | 
						|
        auto x1 = VecType::load(srcStart + srcXStep * 1);
 | 
						|
        auto x2 = VecType::load(srcStart + srcXStep * 2);
 | 
						|
        auto x3 = VecType::load(srcStart + srcXStep * 3);
 | 
						|
        auto m0 = x0, m1 = x3;
 | 
						|
        VecType::mla(m0, x1 + x2, c0);
 | 
						|
        VecType::mla(m1, x1 - x2, c0);
 | 
						|
        VecType::save(dstStart + dstXStep * 0, m0);
 | 
						|
        VecType::save(dstStart + dstXStep * 1, m1);
 | 
						|
        
 | 
						|
        srcStart += srcZStep;
 | 
						|
        dstStart += dstZStep;
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
static void _destTransformUnit4x3(const float* srcStart, float* dstStart, size_t srcXStep, size_t srcZStep, size_t dstXStep, size_t dstZStep, size_t countC4) {
 | 
						|
    using VecType = MNN::Math::Vec<float, 4>;
 | 
						|
    VecType c0(0.5f);
 | 
						|
    for (int i = 0; i < countC4; ++i) {
 | 
						|
        auto x0 = VecType::load(srcStart + srcXStep * 0);
 | 
						|
        auto x1 = VecType::load(srcStart + srcXStep * 1);
 | 
						|
        auto x2 = VecType::load(srcStart + srcXStep * 2);
 | 
						|
        auto x3 = VecType::load(srcStart + srcXStep * 3);
 | 
						|
        auto m0 = x0 + (x1 + x2) * 0.5;
 | 
						|
        auto m1 = (x1 - x2) * 0.5;
 | 
						|
        auto m2 = x3 + (x1 + x2) * 0.5;
 | 
						|
        VecType::save(dstStart + dstXStep * 0, m0);
 | 
						|
        VecType::save(dstStart + dstXStep * 1, m1);
 | 
						|
        VecType::save(dstStart + dstXStep * 2, m2);
 | 
						|
        
 | 
						|
        srcStart += srcZStep;
 | 
						|
        dstStart += dstZStep;
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
// Picks the output-transform kernel for tile size `alpha` and output unit `unit`.
// Returns nullptr for unsupported combinations.
WinogradInt8Helper::DstTransFunc WinogradInt8Helper::chooseDestTransform(int alpha, int unit) {
    // Only alpha == 4 output transforms are implemented (unit 2 and unit 3).
    if (alpha != 4) {
        return nullptr;
    }
    switch (unit) {
        case 2:
            return _destTransformUnit4x2;
        case 3:
            return _destTransformUnit4x3;
        default:
            return nullptr;
    }
}
 | 
						|
 | 
						|
// 1D weight transform: reads a kernel row/column at srcStep, writes alpha values at dstStep,
// and returns true if any transformed value fell outside the int8 range.
using WeightTransFunc = bool (*)(const int8_t* srcStart, int8_t* dstStart, size_t srcStep, size_t dstStep);
 | 
						|
 | 
						|
// 1D weight transform for a 3-tap kernel into 4 values:
//   {g0, g0 + g1 + g2, g0 - g1 + g2, g2}
// Each result is stored truncated to int8 at dstStart[k * dstStep]; returns true if any
// result did not fit in int8 (all results are still written).
static bool _weightTransUnit3x4(const int8_t* srcStart, int8_t* dstStart, size_t srcStep, size_t dstStep) {
    const int32_t g0 = srcStart[0];
    const int32_t g1 = srcStart[srcStep];
    const int32_t g2 = srcStart[2 * srcStep];
    const int32_t transformed[4] = {g0, g0 + g1 + g2, g0 - g1 + g2, g2};

    constexpr int32_t lo = std::numeric_limits<int8_t>::min();
    constexpr int32_t hi = std::numeric_limits<int8_t>::max();
    bool outOfRange = false;
    for (int k = 0; k < 4; ++k) {
        const int32_t v = transformed[k];
        if (v < lo || v > hi) {
            outOfRange = true;
        }
        dstStart[k * dstStep] = static_cast<int8_t>(v);
    }
    return outOfRange;
}
 | 
						|
 | 
						|
// 1D weight transform for a 2-tap kernel into 4 values:
//   {g0, g0 + g1, g0 - g1, g1}
// Each result is stored truncated to int8 at dstStart[k * dstStep]; returns true if any
// result did not fit in int8 (all results are still written).
static bool _weightTransUnit2x4(const int8_t* srcStart, int8_t* dstStart, size_t srcStep, size_t dstStep) {
    const int32_t g0 = srcStart[0];
    const int32_t g1 = srcStart[srcStep];
    const int32_t transformed[4] = {g0, g0 + g1, g0 - g1, g1};

    constexpr int32_t lo = std::numeric_limits<int8_t>::min();
    constexpr int32_t hi = std::numeric_limits<int8_t>::max();
    bool outOfRange = false;
    for (int k = 0; k < 4; ++k) {
        const int32_t v = transformed[k];
        if (v < lo || v > hi) {
            outOfRange = true;
        }
        dstStart[k * dstStep] = static_cast<int8_t>(v);
    }
    return outOfRange;
}
 | 
						|
 | 
						|
// Picks the 1D weight transform for tile size `alpha` and kernel size `kernel`.
// Returns nullptr for unsupported combinations.
static WeightTransFunc _chooseWeightTransform(int alpha, int kernel) {
    // Only alpha == 4 weight transforms are implemented (kernel 3 and kernel 2).
    if (alpha != 4) {
        return nullptr;
    }
    switch (kernel) {
        case 3:
            return _weightTransUnit3x4;
        case 2:
            return _weightTransUnit2x4;
        default:
            return nullptr;
    }
}
 | 
						|
 | 
						|
// Stores the convolution parameters and int8 core, and derives the Winograd tile size per
// axis: alpha = output unit + kernel size - 1.
WinogradInt8Helper::WinogradInt8Helper(int unitY, int unitX, const Convolution2DCommon* common, const CoreInt8Functions* core) {
    mInt8Core = core;
    mCommon = common;
    mAlphaX = unitX + common->kernelX() - 1;
    mAlphaY = unitY + common->kernelY() - 1;
}
 | 
						|
 | 
						|
std::shared_ptr<Tensor> WinogradInt8Helper::allocTransformWeight(const Tensor* weightSrc) {
 | 
						|
    int UNIT, SRC_UNIT, DST_XUNIT;
 | 
						|
    mInt8Core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
 | 
						|
    int oc4 = UP_DIV(mCommon->outputCount(), UNIT), ic4 = UP_DIV(mCommon->inputCount(), SRC_UNIT);
 | 
						|
    return std::shared_ptr<Tensor>(Tensor::createDevice<int8_t>({mAlphaY, mAlphaX, oc4, ic4, UNIT, SRC_UNIT}));
 | 
						|
}
 | 
						|
// whether transform success without overflow, only detect overflow when weightDst == nullptr
 | 
						|
bool WinogradInt8Helper::transformWeight(const Tensor* weightSrc, Tensor* weightDst) {
 | 
						|
    bool fake = (weightDst == nullptr); // fake transform, only for detect overflow
 | 
						|
    int UNIT, SRC_UNIT, DST_XUNIT;
 | 
						|
    mInt8Core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
 | 
						|
    int oc = mCommon->outputCount(), ic = mCommon->inputCount();
 | 
						|
    int kernelY = mCommon->kernelY(), kernelX = mCommon->kernelX();
 | 
						|
    auto transFuncY = _chooseWeightTransform(mAlphaY, kernelY);
 | 
						|
    auto transFuncX = _chooseWeightTransform(mAlphaX, kernelX);
 | 
						|
    mValid = (transFuncY != nullptr || kernelY == 1);
 | 
						|
    mValid &= (transFuncX != nullptr || kernelX == 1);
 | 
						|
    if (!mValid) {
 | 
						|
        return mValid;
 | 
						|
    }
 | 
						|
    // assign new T[xx] to shared_ptr<T[]> is not support due to bug of some compiler (c++11)
 | 
						|
    // so not use: std::shared_ptr<int8_t[]> cache(new int8_t[xx])
 | 
						|
    std::shared_ptr<int8_t> cache(new int8_t[mAlphaY * kernelX + mAlphaY * mAlphaX * UNIT * SRC_UNIT],
 | 
						|
                                  [](int8_t* ptr) { delete[] ptr; });
 | 
						|
    int dstYStep = (fake ? 0 : weightDst->stride(0)), dstXStep = (fake ? 0 : weightDst->stride(1));
 | 
						|
    int dstOZStep = (fake ? 0 : weightDst->stride(2)), dstSZStep = (fake ? 0 : weightDst->stride(3));
 | 
						|
    int8_t* dataDstOrigin;
 | 
						|
    if (fake) {
 | 
						|
        dataDstOrigin = cache.get() + mAlphaY * kernelX;
 | 
						|
        memset(dataDstOrigin, 0, mAlphaY * mAlphaX * UNIT * SRC_UNIT);
 | 
						|
    } else {
 | 
						|
        dataDstOrigin = weightDst->host<int8_t>();
 | 
						|
        memset(dataDstOrigin, 0, weightDst->size());
 | 
						|
    }
 | 
						|
    
 | 
						|
    bool overflow = false;
 | 
						|
    for (int oz = 0; oz < oc; ++oz) {
 | 
						|
        int oz4 = oz / UNIT, ozRemain = oz % UNIT;
 | 
						|
        for (int sz = 0; sz < ic; ++sz) {
 | 
						|
            int sz4 = sz / SRC_UNIT, szRemain = sz % SRC_UNIT;
 | 
						|
            auto dataSrcZ = weightSrc->host<int8_t>() + (ic * oz + sz) * kernelY * kernelX;
 | 
						|
            auto dataDstZ = dataDstOrigin + oz4 * dstOZStep + sz4 * dstSZStep + ozRemain * SRC_UNIT + szRemain;
 | 
						|
            for (int i = 0; i < kernelX; ++i) {
 | 
						|
                if (kernelY != 1) {
 | 
						|
                    overflow |= transFuncY(dataSrcZ + i, cache.get() + i, kernelX, kernelX);
 | 
						|
                } else {
 | 
						|
                    cache.get()[i] = dataSrcZ[i];
 | 
						|
                }
 | 
						|
            }
 | 
						|
            int yLen = (kernelY == 1 ? 1 : mAlphaY);
 | 
						|
            for (int i = 0; i < yLen; ++i) {
 | 
						|
                if (kernelX != 1) {
 | 
						|
                    overflow |= transFuncX(cache.get() + i * kernelX, dataDstZ + i * dstYStep, 1, dstXStep);
 | 
						|
                } else {
 | 
						|
                    dataDstZ[i * dstYStep] = cache.get()[i * kernelX];
 | 
						|
                }
 | 
						|
            }
 | 
						|
        }
 | 
						|
    }
 | 
						|
    return !overflow;
 | 
						|
}
 | 
						|
// when overflow occur, return true
 | 
						|
bool WinogradInt8Helper::weightOverflow(const Tensor* weight, int unitY, int unitX, const Convolution2DCommon* common, const CoreInt8Functions* core) {
 | 
						|
    WinogradInt8Helper helper(unitY, unitX, common, core);
 | 
						|
    return !(helper.transformWeight(weight, nullptr));
 | 
						|
}
 | 
						|
// when overflow occur or not support, return true
 | 
						|
bool WinogradInt8Helper::featureOverflow(const Tensor* input, int alphaY, int alphaX) {
 | 
						|
    std::map<int, std::pair<int8_t, int8_t>> limit2D = {
 | 
						|
#ifdef MNN_USE_SSE
 | 
						|
        {4, {-32, 31}} // int6
 | 
						|
#else
 | 
						|
        {4, {-64, 63}} // int6
 | 
						|
#endif
 | 
						|
    }, limit1D = {
 | 
						|
#ifdef MNN_USE_SSE
 | 
						|
        {4, {-64, 63}} // int7
 | 
						|
#else
 | 
						|
        {4, {-128, 127}} // int7
 | 
						|
#endif
 | 
						|
    };
 | 
						|
    auto quantAttr = TensorUtils::getDescribe(input)->quantAttr;
 | 
						|
    if (quantAttr == nullptr) {
 | 
						|
        MNN_ERROR("Tensor quantAttr should not be nullptr\n");
 | 
						|
        return true;
 | 
						|
    }
 | 
						|
    auto iter = limit2D.end();
 | 
						|
    if (alphaY == 1 || alphaX == 1) {
 | 
						|
        iter = limit1D.find(ALIMAX(alphaY, alphaX));
 | 
						|
    } else if (alphaY == alphaX) {
 | 
						|
        iter = limit2D.find(alphaY);
 | 
						|
    }
 | 
						|
    
 | 
						|
    bool overflow = (quantAttr->min < iter->second.first || quantAttr->max > iter->second.second);
 | 
						|
    return overflow;
 | 
						|
}
 | 
						|
 | 
						|
}
 |