Fix bug for RemoveInvalidCast for set topkv2 error

This commit is contained in:
xiaying 2021-08-03 15:45:19 +08:00
parent c4f9daa4e6
commit fd52fef0d8
1 changed files with 1 additions and 7 deletions

View File

@ -13,7 +13,6 @@
#include <algorithm> #include <algorithm>
#include "../PostTreatUtils.hpp" #include "../PostTreatUtils.hpp"
class RemoveInvalidCast : public PostConverter { class RemoveInvalidCast : public PostConverter {
public: public:
virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override { virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
@ -62,7 +61,7 @@ public:
case MNN::OpType_TopKV2: case MNN::OpType_TopKV2:
types[op->outputIndexes[0]] = types[op->inputIndexes[0]]; types[op->outputIndexes[0]] = types[op->inputIndexes[0]];
if (op->outputIndexes.size() > 1) { if (op->outputIndexes.size() > 1) {
types[op->outputIndexes[0]] = MNN::DataType_DT_INT32; types[op->outputIndexes[1]] = MNN::DataType_DT_INT32;
} }
break; break;
case MNN::OpType_ScatterNd: case MNN::OpType_ScatterNd:
@ -73,11 +72,6 @@ public:
types[op->outputIndexes[0]] = types[op->inputIndexes[2]]; types[op->outputIndexes[0]] = types[op->inputIndexes[2]];
break; break;
default: default:
if (op->inputIndexes.size() > 0) {
for (int i=0; i<op->outputIndexes.size(); ++i) {
types[op->outputIndexes[i]] = types[op->inputIndexes[0]];
}
}
break; break;
} }
} }