| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  ShapeMatMul.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2019/01/10.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | #include "shape/SizeComputer.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | #include "core/Macro.h"
 | 
					
						
							|  |  |  | #include "core/TensorUtils.hpp"
 | 
					
						
							| 
									
										
										
										
											2023-12-04 11:12:20 +08:00
										 |  |  | #include "core/OpCommonUtils.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class MatMulSizeComputer : public SizeComputer { | 
					
						
							| 
									
										
										
										
											2023-12-04 11:12:20 +08:00
										 |  |  |     static void _getTranspose(const MNN::Op* op, bool& transposeA, bool& transposeB) { | 
					
						
							|  |  |  |         transposeA = false; | 
					
						
							|  |  |  |         transposeB = false; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         if (op->type() == OpType_MatMul) { | 
					
						
							|  |  |  |             transposeA = op->main_as_MatMul()->transposeA(); | 
					
						
							|  |  |  |             transposeB = op->main_as_MatMul()->transposeB(); | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             // BatchMatMul
 | 
					
						
							|  |  |  |             transposeA = op->main_as_BatchMatMulParam()->adjX(); | 
					
						
							|  |  |  |             transposeB = op->main_as_BatchMatMulParam()->adjY(); | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2023-12-04 11:12:20 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                const std::vector<Tensor*>& outputs) const override { | 
					
						
							|  |  |  |         MNN_ASSERT(1 == outputs.size()); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         auto output = outputs[0]; | 
					
						
							| 
									
										
										
										
											2020-05-14 13:19:30 +08:00
										 |  |  |         output->buffer().type = inputs[0]->buffer().type; | 
					
						
							| 
									
										
										
										
											2023-12-04 11:12:20 +08:00
										 |  |  |         bool transposeA; | 
					
						
							|  |  |  |         bool transposeB; | 
					
						
							|  |  |  |         _getTranspose(op, transposeA, transposeB); | 
					
						
							|  |  |  |         int e, l, h; | 
					
						
							|  |  |  |         bool valid = OpCommonUtils::computeMatMulSize(transposeA, transposeB, inputs[0], inputs[1], e, l, h); | 
					
						
							|  |  |  |         if (!valid) { | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2020-05-14 13:19:30 +08:00
										 |  |  |         // Compute BroastCast Dims
 | 
					
						
							| 
									
										
										
										
											2023-12-04 11:12:20 +08:00
										 |  |  |         auto i0Dim = inputs[0]->dimensions(); | 
					
						
							|  |  |  |         auto i1Dim = inputs[1]->dimensions(); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-14 13:19:30 +08:00
										 |  |  |         auto input0 = inputs[0]; | 
					
						
							|  |  |  |         auto input1 = inputs[1]; | 
					
						
							|  |  |  |         auto o0Dim = i0Dim; | 
					
						
							|  |  |  |         if (i1Dim > i0Dim) { | 
					
						
							|  |  |  |             o0Dim = i1Dim; | 
					
						
							|  |  |  |             input0 = inputs[1]; | 
					
						
							|  |  |  |             input1 = inputs[0]; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         auto dimOffset = o0Dim - 2; | 
					
						
							|  |  |  |         output->buffer().dimensions = o0Dim; | 
					
						
							|  |  |  |         const int maxDimensions = dimOffset; | 
					
						
							|  |  |  |         const int diffDimension = input0->dimensions() - input1->dimensions(); | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |         for (int i = 0; i < maxDimensions; i++) { | 
					
						
							|  |  |  |             output->setLength(i, input0->length(i)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (int i = diffDimension; i < maxDimensions; i++) { | 
					
						
							|  |  |  |             const int input1Index = i - diffDimension; | 
					
						
							|  |  |  |             int dim1 = input1->buffer().dim[input1Index].extent; | 
					
						
							|  |  |  |             if (dim1 != output->length(i) && (dim1 != 1 && output->length(i) != 1)) { | 
					
						
							|  |  |  |                 MNN_PRINT("Don't support broadcast for MatMulOp, i0=%d, i1=%d\n", output->length(i), dim1); | 
					
						
							|  |  |  |                 return false; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             if (dim1 == output->length(i)) { | 
					
						
							|  |  |  |                 continue; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             if (dim1 != output->length(i) && (dim1 == 1 || output->length(i) == 1)) { | 
					
						
							|  |  |  |                 output->setLength(i, output->length(i) * dim1); | 
					
						
							|  |  |  |             } else { | 
					
						
							|  |  |  |                 MNN_PRINT("Error, the logic flow should never get here"); | 
					
						
							|  |  |  |                 return false; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         // Last Two dim
 | 
					
						
							| 
									
										
										
										
											2023-12-04 11:12:20 +08:00
										 |  |  |         output->setLength(o0Dim - 2, e); | 
					
						
							|  |  |  |         output->setLength(o0Dim - 1, h); | 
					
						
							|  |  |  |         bool eValid = inputs[0]->dimensions() > 1; | 
					
						
							|  |  |  |         bool hValid = inputs[1]->dimensions() > 1; | 
					
						
							|  |  |  |         int squeezeDim = 0; | 
					
						
							|  |  |  |         if (!eValid) { | 
					
						
							|  |  |  |             squeezeDim++; | 
					
						
							|  |  |  |             output->setLength(o0Dim - 2, h); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (!hValid) { | 
					
						
							|  |  |  |             squeezeDim++; | 
					
						
							|  |  |  |             output->setLength(o0Dim - 1, e); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (squeezeDim > 0) { | 
					
						
							|  |  |  |             output->buffer().dimensions = o0Dim - squeezeDim; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-08-22 20:13:46 +08:00
										 |  |  |         TensorUtils::getDescribe(output)->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     virtual float onComputeFlops(const MNN::Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                  const std::vector<Tensor*>& outputs) const override { | 
					
						
							| 
									
										
										
										
											2023-12-04 11:12:20 +08:00
										 |  |  |         bool transposeA; | 
					
						
							|  |  |  |         bool transposeB; | 
					
						
							|  |  |  |         _getTranspose(op, transposeA, transposeB); | 
					
						
							|  |  |  |         int e=0, l=0, h=0; | 
					
						
							|  |  |  |         OpCommonUtils::computeMatMulSize(transposeA, transposeB, inputs[0], inputs[1], e, l, h); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         Tensor* C       = outputs[0]; | 
					
						
							|  |  |  |         auto flops = (float)e * l * h / FLOPS_M; | 
					
						
							| 
									
										
										
										
											2023-12-04 11:12:20 +08:00
										 |  |  |         bool eValid = inputs[0]->dimensions() > 1; | 
					
						
							|  |  |  |         bool hValid = inputs[1]->dimensions() > 1; | 
					
						
							|  |  |  |         int squeezeDim = 0; | 
					
						
							|  |  |  |         if (!eValid) { | 
					
						
							|  |  |  |             squeezeDim++; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (!hValid) { | 
					
						
							|  |  |  |             squeezeDim++; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (int i=0; i<C->dimensions() - 2 + squeezeDim; ++i) { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             flops *= C->length(i); | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         return flops; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | REGISTER_SHAPE(MatMulSizeComputer, OpType_MatMul); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  | REGISTER_SHAPE(MatMulSizeComputer, OpType_BatchMatMul); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | } // namespace MNN
 |