Compare commits

...

4 Commits

Author SHA1 Message Date
Yexuan Wu d20c45eeca
Merge d0d879384a into dcb8cc7cf5 2025-06-13 15:25:25 +08:00
jxt1234 dcb8cc7cf5
Merge pull request #3613 from futz12/train-tools-cpp11-fix
修复train demo使用了cpp11以上标准的代码导致无法编译
2025-06-13 13:24:09 +08:00
futz12 89d0ed6df4 modified: tools/train/source/demo/ImageDatasetDemo.cpp
modified:   tools/train/source/demo/distillTrainQuant.cpp
2025-06-10 20:19:27 +08:00
futz12 d0d879384a modified: source/backend/cpu/x86_x64/avx/GemmInt8.cpp 2025-06-10 19:37:25 +08:00
3 changed files with 16 additions and 2 deletions

View File

@ -23,6 +23,20 @@ static inline void MNN__mm_storeu_si64(void* add, __m128i value) {
_mm_storeu_ps(temp, _mm_castsi128_ps(value));
::memcpy(add, temp, sizeof(int64_t));
}
#if defined(_MSC_VER) && !defined(_mm256_extract_epi64)
static inline uint64_t _mm256_extract_epi64(__m256i a, const int index)
{
typedef union {
__m256i v;
uint64_t i64[4];
} extractor;
extractor u;
u.v = a;
return u.i64[index];
}
#endif
} // namespace
#define POSTTREAT(N) \

View File

@ -64,7 +64,7 @@ public:
auto converImagesToFormat = CV::RGB;
int resizeHeight = 224;
int resizeWidth = 224;
std::vector<float> scales = {1/255.0, 1/255.0, 1/255.0};
std::vector<float> scales = {1/255.0f, 1/255.0f, 1/255.0f};
std::shared_ptr<ImageDataset::ImageConfig> config(ImageDataset::ImageConfig::create(converImagesToFormat, resizeHeight, resizeWidth, scales));
bool readAllImagesToMemory = false;
auto dataset = ImageDataset::create(pathToImages, pathToImageTxt, config.get(), readAllImagesToMemory);

View File

@ -84,7 +84,7 @@ void _train(std::shared_ptr<Module> origin, std::shared_ptr<Module> optmized, st
int resizeHeight = 224;
int resizeWidth = 224;
std::vector<float> means = {127.5, 127.5, 127.5};
std::vector<float> scales = {1/127.5, 1/127.5, 1/127.5};
std::vector<float> scales = {1/127.5f, 1/127.5f, 1/127.5f};
std::vector<float> cropFraction = {0.875, 0.875}; // center crop fraction for height and width
bool centerOrRandomCrop = false; // true for random crop
std::shared_ptr<ImageDataset::ImageConfig> datasetConfig(ImageDataset::ImageConfig::create(converImagesToFormat, resizeHeight, resizeWidth, scales, means, cropFraction, centerOrRandomCrop));