| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  CPUWhere.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2018/08/31.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | #include "backend/cpu/CPUWhere.hpp"
 | 
					
						
							|  |  |  | #include "backend/cpu/CPUBackend.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ErrorCode CPUWhere::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) { | 
					
						
							| 
									
										
										
										
											2019-07-04 19:38:23 +08:00
										 |  |  |     auto& ib           = inputs[0]->buffer(); | 
					
						
							|  |  |  |     int32_t* inputData = inputs[0]->host<int32_t>(); | 
					
						
							|  |  |  |     auto outputData    = outputs[0]->host<int32_t>(); | 
					
						
							| 
									
										
										
										
											2020-07-04 01:21:30 +08:00
										 |  |  |     auto inputTotal = inputs[0]->elementSize(); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     std::vector<int32_t> trueVec; | 
					
						
							| 
									
										
										
										
											2020-07-04 01:21:30 +08:00
										 |  |  |     for (int i = 0; i < inputTotal; i++) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         if (inputData[i] > 0) { | 
					
						
							|  |  |  |             trueVec.push_back(i); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |     //MNN_ASSERT(outputs[0]->batch() == trueVec.size());
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     for (int i = 0; i < trueVec.size(); i++) { | 
					
						
							|  |  |  |         int index = trueVec[i]; | 
					
						
							|  |  |  |         for (int j = 0; j < ib.dimensions; j++) { | 
					
						
							| 
									
										
										
										
											2021-01-06 16:29:37 +08:00
										 |  |  |             int result    = ib.dim[j].stride == 0 ? index : index / ib.dim[j].stride; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             index         = index - result * ib.dim[j].stride; | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  |             outputData[i * ib.dimensions + j] = result; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return NO_ERROR; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class CPUWhereCreator : public CPUBackend::Creator { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, | 
					
						
							|  |  |  |                                 const MNN::Op* op, Backend* backend) const override { | 
					
						
							|  |  |  |         return new CPUWhere(backend); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | REGISTER_CPU_OP_CREATOR(CPUWhereCreator, OpType_Where); | 
					
						
							|  |  |  | } // namespace MNN
 |