2019-04-17 10:49:11 +08:00
|
|
|
|
//
|
|
|
|
|
// StrassenMatmulComputor.hpp
|
|
|
|
|
// MNN
|
|
|
|
|
//
|
|
|
|
|
// Created by MNN on 2019/02/11.
|
|
|
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
#ifndef StrassenMatmulComputor_hpp
|
|
|
|
|
#define StrassenMatmulComputor_hpp
|
|
|
|
|
|
2019-05-05 20:27:57 +08:00
|
|
|
|
#include <functional>
|
2019-12-27 22:16:57 +08:00
|
|
|
|
#include "core/Backend.hpp"
|
2019-04-17 10:49:11 +08:00
|
|
|
|
namespace MNN {
|
2020-07-04 01:21:30 +08:00
|
|
|
|
/**
|
|
|
|
|
Based on
|
|
|
|
|
Boyer, B., Dumas, J.-G., Pernet, C., & Zhou, W. (2007). Memory efficient scheduling of Strassen-Winograd's matrix multiplication algorithm. Proceedings of the 2009 International Symposium on Symbolic and Algebraic Computation (ISSAC '09), 55. ACM Press. Retrieved from http://arxiv.org/abs/0707.2347
|
|
|
|
|
|
|
|
|
|
Use Table 2
|
|
|
|
|
*/
|
2019-06-17 20:10:35 +08:00
|
|
|
|
class StrassenMatrixComputor {
|
2019-04-17 10:49:11 +08:00
|
|
|
|
public:
|
2020-02-26 09:57:17 +08:00
|
|
|
|
StrassenMatrixComputor(Backend* bn, bool multithread, int maxDepth);
|
2019-04-17 10:49:11 +08:00
|
|
|
|
virtual ~StrassenMatrixComputor();
|
2019-06-17 20:10:35 +08:00
|
|
|
|
|
2019-04-17 10:49:11 +08:00
|
|
|
|
/*
|
|
|
|
|
It's assume that:
|
2020-07-04 01:21:30 +08:00
|
|
|
|
A is a matrix where each element is a (4,1) vector : lC4, e, 4
|
|
|
|
|
B is a matrix where each element is a (hP,1) vector : h, l, hP
|
2019-04-17 10:49:11 +08:00
|
|
|
|
inputs[0] is the transpose of A: AT, inputs[1] is the transpose of B: BT
|
|
|
|
|
outputs[0] is the transpose of C: CT
|
2020-07-04 01:21:30 +08:00
|
|
|
|
C is a matrix where each element is a (4,1) vector, the same as A : hC4, e, 4
|
2019-04-17 10:49:11 +08:00
|
|
|
|
|
2020-07-04 01:21:30 +08:00
|
|
|
|
if (inputs.size() > 2) {
|
|
|
|
|
inputs[2] is origin CO: CT
|
|
|
|
|
CO can be the same same as C or broadcast in lenght(1): hC4, e, 4 or hC4, 1, 4
|
|
|
|
|
}
|
|
|
|
|
Compute: C = alpha * AB + beta * CO , alpha must be 1.0f
|
|
|
|
|
|
|
|
|
|
postParameters:
|
|
|
|
|
0: alpha
|
|
|
|
|
1: beta
|
|
|
|
|
2: min
|
|
|
|
|
3: max
|
|
|
|
|
|
|
|
|
|
if (postParameters.empty()) {
|
|
|
|
|
alpha = 1.0f
|
|
|
|
|
beta = 0.0f;
|
|
|
|
|
min = -FLT_MAX
|
|
|
|
|
max = FLT_MAX
|
|
|
|
|
}
|
2019-04-17 10:49:11 +08:00
|
|
|
|
*/
|
2021-06-11 17:17:13 +08:00
|
|
|
|
ErrorCode onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const std::vector<float>& postParameters = {}, int l = 0, int h = 0);
|
2019-06-17 20:10:35 +08:00
|
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
|
void onExecute(const uint8_t* AT = nullptr, const uint8_t* BT = nullptr, const uint8_t* COT = nullptr, uint8_t* CT = nullptr);
|
2019-06-17 20:10:35 +08:00
|
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
|
void onReset();
|
2019-06-17 20:10:35 +08:00
|
|
|
|
protected:
|
|
|
|
|
Backend* backend() const {
|
|
|
|
|
return mBackend;
|
|
|
|
|
}
|
2019-04-17 10:49:11 +08:00
|
|
|
|
|
|
|
|
|
private:
|
2021-06-11 17:17:13 +08:00
|
|
|
|
struct MatrixInfo {
|
|
|
|
|
int stackIndex;
|
|
|
|
|
int offsetBytes;
|
|
|
|
|
int lineStrideBytes;
|
|
|
|
|
};
|
|
|
|
|
ErrorCode _generateMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, int currentDepth, const std::vector<float>& postParameters);
|
|
|
|
|
ErrorCode _generateTrivalMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, const std::vector<float>& postParameters);
|
|
|
|
|
ErrorCode _generateBasicMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, const std::vector<float>& postParameters);
|
2019-04-17 10:49:11 +08:00
|
|
|
|
|
2020-02-26 09:57:17 +08:00
|
|
|
|
std::vector<std::pair<std::function<void(int tId)>, int>> mFunctions;
|
2019-04-17 10:49:11 +08:00
|
|
|
|
int mMaxDepth;
|
2020-02-26 09:57:17 +08:00
|
|
|
|
bool mSupportMultiThread;
|
2019-06-17 20:10:35 +08:00
|
|
|
|
|
|
|
|
|
Backend* mBackend;
|
2021-06-11 17:17:13 +08:00
|
|
|
|
|
|
|
|
|
std::vector<uint8_t*> mStack;
|
2019-04-17 10:49:11 +08:00
|
|
|
|
};
|
|
|
|
|
} // namespace MNN
|
|
|
|
|
|
|
|
|
|
#endif /* StrassenMatmulComputor_hpp */
|