mirror of https://github.com/alibaba/MNN.git
Fix bug for RemoveInvalidCast for set topkv2 error
This commit is contained in:
parent
c4f9daa4e6
commit
fd52fef0d8
|
|
@ -13,7 +13,6 @@
|
|||
#include <algorithm>
|
||||
#include "../PostTreatUtils.hpp"
|
||||
|
||||
|
||||
class RemoveInvalidCast : public PostConverter {
|
||||
public:
|
||||
virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
|
||||
|
|
@ -62,7 +61,7 @@ public:
|
|||
case MNN::OpType_TopKV2:
|
||||
types[op->outputIndexes[0]] = types[op->inputIndexes[0]];
|
||||
if (op->outputIndexes.size() > 1) {
|
||||
types[op->outputIndexes[0]] = MNN::DataType_DT_INT32;
|
||||
types[op->outputIndexes[1]] = MNN::DataType_DT_INT32;
|
||||
}
|
||||
break;
|
||||
case MNN::OpType_ScatterNd:
|
||||
|
|
@ -73,11 +72,6 @@ public:
|
|||
types[op->outputIndexes[0]] = types[op->inputIndexes[2]];
|
||||
break;
|
||||
default:
|
||||
if (op->inputIndexes.size() > 0) {
|
||||
for (int i=0; i<op->outputIndexes.size(); ++i) {
|
||||
types[op->outputIndexes[i]] = types[op->inputIndexes[0]];
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue