Fix bug for RemoveInvalidCast for set topkv2 error

This commit is contained in:
xiaying 2021-08-03 15:45:19 +08:00
parent c4f9daa4e6
commit fd52fef0d8
1 changed files with 1 additions and 7 deletions

View File

@ -13,7 +13,6 @@
#include <algorithm>
#include "../PostTreatUtils.hpp"
class RemoveInvalidCast : public PostConverter {
public:
virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
@ -62,7 +61,7 @@ public:
case MNN::OpType_TopKV2:
types[op->outputIndexes[0]] = types[op->inputIndexes[0]];
if (op->outputIndexes.size() > 1) {
types[op->outputIndexes[0]] = MNN::DataType_DT_INT32;
types[op->outputIndexes[1]] = MNN::DataType_DT_INT32;
}
break;
case MNN::OpType_ScatterNd:
@ -73,11 +72,6 @@ public:
types[op->outputIndexes[0]] = types[op->inputIndexes[2]];
break;
default:
if (op->inputIndexes.size() > 0) {
for (int i=0; i<op->outputIndexes.size(); ++i) {
types[op->outputIndexes[i]] = types[op->inputIndexes[0]];
}
}
break;
}
}