mirror of https://github.com/alibaba/MNN.git
				
				
				
			Fix bug for RemoveInvalidCast for set topkv2 error
This commit is contained in:
		
							parent
							
								
									c4f9daa4e6
								
							
						
					
					
						commit
						fd52fef0d8
					
				| 
						 | 
					@ -13,7 +13,6 @@
 | 
				
			||||||
#include <algorithm>
 | 
					#include <algorithm>
 | 
				
			||||||
#include "../PostTreatUtils.hpp"
 | 
					#include "../PostTreatUtils.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
class RemoveInvalidCast : public PostConverter {
 | 
					class RemoveInvalidCast : public PostConverter {
 | 
				
			||||||
public:
 | 
					public:
 | 
				
			||||||
    virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
 | 
					    virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
 | 
				
			||||||
| 
						 | 
					@ -62,7 +61,7 @@ public:
 | 
				
			||||||
                case MNN::OpType_TopKV2:
 | 
					                case MNN::OpType_TopKV2:
 | 
				
			||||||
                    types[op->outputIndexes[0]] = types[op->inputIndexes[0]];
 | 
					                    types[op->outputIndexes[0]] = types[op->inputIndexes[0]];
 | 
				
			||||||
                    if (op->outputIndexes.size() > 1) {
 | 
					                    if (op->outputIndexes.size() > 1) {
 | 
				
			||||||
                        types[op->outputIndexes[0]] = MNN::DataType_DT_INT32;
 | 
					                        types[op->outputIndexes[1]] = MNN::DataType_DT_INT32;
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                    break;
 | 
					                    break;
 | 
				
			||||||
                case MNN::OpType_ScatterNd:
 | 
					                case MNN::OpType_ScatterNd:
 | 
				
			||||||
| 
						 | 
					@ -73,11 +72,6 @@ public:
 | 
				
			||||||
                    types[op->outputIndexes[0]] = types[op->inputIndexes[2]];
 | 
					                    types[op->outputIndexes[0]] = types[op->inputIndexes[2]];
 | 
				
			||||||
                    break;
 | 
					                    break;
 | 
				
			||||||
                default:
 | 
					                default:
 | 
				
			||||||
                    if (op->inputIndexes.size() > 0) {
 | 
					 | 
				
			||||||
                        for (int i=0; i<op->outputIndexes.size(); ++i) {
 | 
					 | 
				
			||||||
                            types[op->outputIndexes[i]] = types[op->inputIndexes[0]];
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                    break;
 | 
					                    break;
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue