// Mirror of https://github.com/alibaba/MNN.git (C++, 95 lines, 2.5 KiB).
|
#include "core/Macro.h"
|
||
|
#include "../compute/CommonOptFunction.h"
|
||
|
#ifdef MNN_USE_NEON
|
||
|
#include <arm_neon.h>
|
||
|
|
||
|
// Writes the matmul packing parameters used by this NEON backend.
//   *eP is always 12 and *lP is always 1.
//   *hP is the panel width that MNNPackForMatMul_B in this file packs B into:
//   8 when compiled for AArch64, 4 otherwise.
// The outputs are written in the order eP, lP, hP (relevant only if the
// caller passes aliasing pointers).
void MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
#ifdef __aarch64__
    constexpr int kPanelWidth = 8;  // matches the 8-wide packing path
#else
    constexpr int kPanelWidth = 4;  // matches the 4-wide packing path
#endif
    *eP = 12;
    *lP = 1;
    *hP = kPanelWidth;
}
|
||
|
|
||
|
#ifdef __aarch64__
|
||
|
// 8-wide packing kernel with C linkage, defined outside this translation unit
// (presumably a hand-written AArch64 routine — confirm). MNNPackForMatMul_B
// calls it as MNNPackC8(dest, source, l, h) on its transposed path.
extern "C" void MNNPackC8(float* dest, const float* source, size_t l, size_t h);
|
||
|
|
||
|
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose) {
|
||
|
auto hP = (int)h / 8;
|
||
|
auto hR = (int)hP * 8;
|
||
|
if (hR != h) {
|
||
|
::memset(dest, 0, UP_DIV(h, 8)*8*l*sizeof(float));
|
||
|
}
|
||
|
if (!transpose) {
|
||
|
for (int y=0; y<hP; ++y) {
|
||
|
auto destY = dest + y * 8 * l;
|
||
|
auto sourceY = source + y * 8;
|
||
|
for (int x=0; x<l; ++x) {
|
||
|
::memcpy(destY + 8 * x, sourceY + x * h, 8 * sizeof(float));
|
||
|
}
|
||
|
}
|
||
|
auto hRemain = h - hR;
|
||
|
if (hRemain > 0) {
|
||
|
auto destY = dest + hP * 8 * l;
|
||
|
auto sourceY = source + hP * 8;
|
||
|
for (int x=0; x<l; ++x) {
|
||
|
::memcpy(destY + 8 * x, sourceY + x * h, hRemain * sizeof(float));
|
||
|
}
|
||
|
}
|
||
|
return;
|
||
|
}
|
||
|
int lC8 = (int)l / 8;
|
||
|
auto lR = lC8 * 8;
|
||
|
if (hP > 0 && lC8 > 0) {
|
||
|
MNNPackC8(dest, source, l, h);
|
||
|
}
|
||
|
for (int y=hR; y<h; ++y) {
|
||
|
auto yR = y % 8;
|
||
|
auto yC = hP;
|
||
|
for (int x=0; x<l; ++x) {
|
||
|
dest[x * 8 + yR + yC * 8 * l] = source[x + y * l];
|
||
|
}
|
||
|
}
|
||
|
for (int y=0; y<hR; ++y) {
|
||
|
auto yR = y % 8;
|
||
|
auto yC = y / 8;
|
||
|
for (int x=lR; x<l; ++x) {
|
||
|
dest[x * 8 + yR + yC * 8 * l] = source[x + y * l];
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
#else
|
||
|
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose) {
|
||
|
if (!transpose) {
|
||
|
auto hP = h / 4;
|
||
|
auto hR = hP * 4;
|
||
|
if (hR != h) {
|
||
|
::memset(dest, 0, UP_DIV(h, 4)*4*l*sizeof(float));
|
||
|
}
|
||
|
for (int y=0; y<hP; ++y) {
|
||
|
auto destY = dest + y * 4 * l;
|
||
|
auto sourceY = source + y * 4;
|
||
|
for (int x=0; x<l; ++x) {
|
||
|
::memcpy(destY + 4 * x, sourceY + x * h, 4 * sizeof(float));
|
||
|
}
|
||
|
}
|
||
|
auto hRemain = h - hR;
|
||
|
if (hRemain > 0) {
|
||
|
auto destY = dest + hP * 4 * l;
|
||
|
auto sourceY = source + hP * 4;
|
||
|
for (int x=0; x<l; ++x) {
|
||
|
::memcpy(destY + 4 * x, sourceY + x * h, hRemain * sizeof(float));
|
||
|
}
|
||
|
}
|
||
|
return;
|
||
|
}
|
||
|
MNNPackC4(dest, source, l, h);
|
||
|
}
|
||
|
#endif
|
||
|
|
||
|
#endif
|